CI: 04/14/25 upstream sync #358

Open: wants to merge 723 commits into base: rocm-main

rocm-repo-management-api-2[bot]

Daily sync with upstream

cperivol and others added 30 commits April 4, 2025 07:21
…s rather than #elems.

PiperOrigin-RevId: 743953910
…efore checking it's not too long, so that e.g. `my_1d_array[:, ...]` can be treated as a slice rather than generating a gather operation.

PiperOrigin-RevId: 743986126
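
The commit title above is truncated, so the exact change is not shown here. As a rough illustration of why an index like `my_1d_array[:, ...]` needs care, the pure-Python sketch below records the index tuple that Python passes to `__getitem__` and drops the `Ellipsis` before comparing the tuple's length with the array rank. `IndexRecorder`, `normalize_index`, and `is_pure_slice` are hypothetical names made up for this sketch; they are not JAX internals.

```python
class IndexRecorder:
    """Returns whatever index object Python passes to __getitem__."""

    def __getitem__(self, index):
        return index


def normalize_index(index, ndim):
    """Hypothetical helper: drop Ellipsis, then check the index length."""
    if not isinstance(index, tuple):
        index = (index,)
    # An Ellipsis may expand to zero dimensions, so it must not count
    # toward the number of indexed axes.
    stripped = tuple(i for i in index if i is not Ellipsis)
    if len(stripped) > ndim:
        raise IndexError(f"too many indices for a {ndim}-d array")
    return stripped


def is_pure_slice(index, ndim):
    """True when every remaining index entry is a slice object."""
    return all(isinstance(i, slice) for i in normalize_index(index, ndim))


# The raw index has two entries even though the target array is 1-D.
raw = IndexRecorder()[:, ...]
```

Checking the length only after the `Ellipsis` is removed lets the two-entry index be classified as a plain slice of a one-dimensional array instead of being rejected or routed to a more general path.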
We need to be careful not to destroy Python objects while using a Python 3.13- critical section to protect C++ state. The critical section might be released when calling back into Python code (much as the GIL may be released in GIL mode).

In this code Key is kept alive by the function already, but the Value may be deleted before the hash table updates are done.

PiperOrigin-RevId: 744008939
Imported from GitHub PR jax-ml#26906

Allows overriding the slice index used by XLA.

More explicit control over which slice a device ends up in is desirable:
- Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice.
- For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host.

(Companion PR in XLA: openxla/xla#23347)
Copybara import of the project:

--
45aa7ce by Georg Stefan Schmid <[email protected]>:

[jax.distributed] Allow overriding XLA slice_index

Merging this change closes jax-ml#26906

COPYBARA_INTEGRATE_REVIEW=jax-ml#26906 from gspschmid:gschmid/jax-override-boot-id 45aa7ce
PiperOrigin-RevId: 744012253
to_elt must run in the parent context, while from_elt must run in the batching
context. We previously had it precisely backward!

Tests didn't catch it because our tests are extremely minimal, and in
particular didn't check a to_elt that binds primitives.
PiperOrigin-RevId: 744478350
PiperOrigin-RevId: 744480338
PiperOrigin-RevId: 744480358
PiperOrigin-RevId: 744480452
PiperOrigin-RevId: 744483310
C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a

However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand.

PiperOrigin-RevId: 744508279
…and `AbstractMesh` is not `jax.sharding.AxisType`.

PiperOrigin-RevId: 744602037
This change also fixes the transpose handling in the lowering and completely removes the use of the TransposeTransform. Instead we rely on strides. If we don't discover any issues with this, we will also remove the transpose transform from the MLIR dialect.

PiperOrigin-RevId: 744618241
I removed `trivial_ctx` from the public `jax.interpreters.partial_eval`
submodule without going through a deprecation cycle, because it is highly
unlikely anyone is using it.

PiperOrigin-RevId: 744645764
…ts/outputs

This introduces version 4 of serialization, fully backwards compatible
with versions 2 and 3.

Fixes: jax-ml#24143
PiperOrigin-RevId: 744652508
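
As a generic illustration of what "fully backwards compatible with versions 2 and 3" can mean for a reader, the sketch below accepts payloads from several format versions and fills in a default for a field that older versions lack. The payload layout, the `extras` field, and `load_record` are assumptions invented for this sketch; they do not describe the actual serialization format touched by the commit.

```python
SUPPORTED_VERSIONS = (2, 3, 4)


def load_record(payload: dict) -> dict:
    """Hypothetical reader that accepts every supported format version."""
    version = payload.get("version")
    if version not in SUPPORTED_VERSIONS:
        raise ValueError(f"unsupported serialization version: {version!r}")
    record = {"version": version, "body": payload["body"]}
    # "extras" stands in for data that only exists from version 4 onward;
    # older payloads get an empty default instead of failing to load.
    record["extras"] = payload.get("extras", {}) if version >= 4 else {}
    return record


old = load_record({"version": 2, "body": "f(x)"})
new = load_record({"version": 4, "body": "f(x)", "extras": {"note": "n"}})
```

The design choice is that new fields are optional on read, so artifacts written by older versions keep loading without a migration step.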
PiperOrigin-RevId: 744659794
justinjfu and others added 28 commits April 11, 2025 12:13
PiperOrigin-RevId: 746546870
Use a count of chips (or omit it if 1) rather than specifying an ICI topology.

Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8
PiperOrigin-RevId: 746547477
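
The renaming rule in the examples above can be sketched as a small helper. `chip_count_name` is a hypothetical function written for illustration (it is not part of the JAX build tooling), and it assumes the old name ends in a two-dimensional `ROWSxCOLS` topology suffix; anything else is returned unchanged.

```python
import re

# Matches an old-style name such as "tpu_v5e_4x2" (assumed 2-D topology only).
_TOPOLOGY_SUFFIX = re.compile(r"(?P<prefix>.+)_(?P<rows>\d+)x(?P<cols>\d+)")


def chip_count_name(old_name: str) -> str:
    """Convert a topology-suffixed name to a chip-count-suffixed name."""
    match = _TOPOLOGY_SUFFIX.fullmatch(old_name)
    if match is None:
        return old_name  # no recognizable topology suffix; leave as-is
    chips = int(match["rows"]) * int(match["cols"])
    prefix = match["prefix"]
    # A single chip needs no suffix; otherwise append "_x<count>".
    return prefix if chips == 1 else f"{prefix}_x{chips}"


print(chip_count_name("tpu_v5e_1x1"))  # tpu_v5e
print(chip_count_name("tpu_v5e_4x2"))  # tpu_v5e_x8
```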
PiperOrigin-RevId: 746554582
…thon as a patch, rolling back.

Reverts b1c96d4

PiperOrigin-RevId: 746565341
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.

PiperOrigin-RevId: 746595857
This parameter is available from jax-ml#23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html.

PiperOrigin-RevId: 746606206
…uts as attention call.

PiperOrigin-RevId: 746616128
…king similar to Llama4.

Llama4 uses (interleaved) chunk attention to support long context.

PiperOrigin-RevId: 746661156
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.

This removes a lot of uninteresting explanations
from a large benchmark.

jax-fixit

PiperOrigin-RevId: 746703003
We no longer have many different implicit types conforming to `Executable`, only `pxla.MeshExecutable` and `pxla.PmapExecutable`. Both are `XlaExecutable` subtypes. So define just one common base class, call it `Executable`, and inherit from just that in both concrete internal executable subtypes.

PiperOrigin-RevId: 746706712
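
The difference between implicit conformance and an explicit common base class can be sketched in plain Python. The class names `Executable`, `MeshExecutable`, and `PmapExecutable` follow the commit message, but the bodies, the `call` method, `ExecutableProtocol`, and `DuckTyped` are assumptions made up for this sketch and do not reflect the real JAX classes.

```python
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable


@runtime_checkable
class ExecutableProtocol(Protocol):
    """Implicit style: any class with a matching method conforms."""

    def call(self, *args): ...


class Executable(ABC):
    """Explicit style: conformance requires inheriting from this base."""

    @abstractmethod
    def call(self, *args): ...


class MeshExecutable(Executable):
    def call(self, *args):
        return ("mesh", args)


class PmapExecutable(Executable):
    def call(self, *args):
        return ("pmap", args)


class DuckTyped:
    """Conforms to the protocol implicitly but not to the base class."""

    def call(self, *args):
        return ("duck", args)
```

With only two concrete implementations, an explicit base class makes the set of conforming types easy to enumerate, whereas a protocol admits any class that happens to have the right methods.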
These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`.

Co-authored-by: Roy Frostig <[email protected]>
PiperOrigin-RevId: 746716734
PiperOrigin-RevId: 746726071
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.

PiperOrigin-RevId: 746735857
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner April 14, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) April 14, 2025 06:02