Skip to content

CI: 03/25/25 upstream sync #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 173 commits into from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

charleshofer and others added 30 commits March 11, 2025 14:57
This code is Mosaic specific, move it to the Mosaic directory.

PiperOrigin-RevId: 738404429
…MAOp`.

Now that we have full control over strides in the lowering, these attributes
are no longer necessary.

PiperOrigin-RevId: 738418852
This callback functionality is only used by JAX and shipped as part of its CUDA and ROCM GPU plugins. Move it into JAX, as part of a wider move of xla/python pieces that belong to JAX into JAX.

PiperOrigin-RevId: 738426489
This is a GPU-specific target.

PiperOrigin-RevId: 738441625
PiperOrigin-RevId: 738443014
…e more cleanups

PiperOrigin-RevId: 738503430
The new xla_extension_version is 320.

PiperOrigin-RevId: 738522486
superbobry and others added 28 commits March 24, 2025 13:51
Indexing is less verbose and is thus easier to read in most cases. The
functional API is really only necessary for masked loads and stores.

PiperOrigin-RevId: 740058341
…ension and load a ref into WGMMAColFragLayout format.

PiperOrigin-RevId: 740068368
PiperOrigin-RevId: 740084295
If `auditwheel show` is executed on `jax` wheel, the following message is printed:

```
INFO:auditwheel.main_show:This does not look like a platform wheel, no ELF executable or shared library file (including compiled Python C extension) found in the wheel archive
```
PiperOrigin-RevId: 740096302
… `:build_jaxlib=false`.

PiperOrigin-RevId: 740115575
Before this change, we handled attrs for initial-style primitives like jit/scan
like this:
1. the traceable would form a jaxpr and see what attrs were touched (by
   jax_getattr or jax_setattr),
2. for each such attr, the traceable would do jax_getattr to get the current
   value, tree-flatten, pass the flat valuesinto the (pure) bind, get the new
   values out, tree-unflatten, then jax_setattr the result.

That approach would error if the function called `jax_setattr` to set a
previously non-existant attr. That is, this would work:

```python
from jax.experimental.attrs import jax_setattr
class Thing: ...
thing = Thing()
jax_setattr(thing, 'x', 1.0)
```
but it wouldn't work under a `jax.jit`.

This commit makes the same code work under a jit. We just
1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation
   are deleted, using a special sentinel value `dne_sentinel` to indicate the
   attribute initially did not exist before tracing;
2. in pjit.py's `_get_states`, when reading initial attr values before the
   pjit_p bind, if the attribute does not exist we don't try to read it and
   instead just use `dne_sentinel` as the value, which is a convenient empty
   pytree;
3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based
   on the current attr states, we map attrs that don't exist to `dne_sentinel`
   (rather than just erroring when the attr doesn't exist, as before).

In short, we use a special value to indicate "does not exist".

If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.

This PR only makes setattr-to-nonexistant-attr work with jit. We'll add scan
etc in follow-ups.
…ult layout inference.

default_vector_size is initialized with `math.inf` and is never `None`.

PiperOrigin-RevId: 740283678
We need to include the type caster for std::string_view if we use nb::cast<std::string_view>.

PiperOrigin-RevId: 740311318
OSS Jax builds for GPU backends split `jaxlib` into three wheels and since we cannot expect a stable C++ ABI among the shared libraries, we refactor to ensure:

1. C++ objects are not created/consumed by different shared libraries.
2. Static objects are declared and defined appropriately.

This PR:

1. Migrates Jax XLA FFI callback handlers from XLA's Internal FFI API to the [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis). Note that we update both CPU and GPU handlers because we cannot mix Internal and External APIs.
2. Updates how FFI GPU handlers are registered, now analogous to how the original GPU custom call was registered.
3. Adds an `xla::ffi::ExecutionContext` member to `ifrt::PjRtLoadedExectuable` holding opaque pointers to callbacks.
4. Updates Jax `callback.py` to call the new FFI callback handlers.

PiperOrigin-RevId: 740327296
Also in passing fix up some header guards and authorship comments.

PiperOrigin-RevId: 740337166
…ne and WG thread semantics

PiperOrigin-RevId: 740371195
PiperOrigin-RevId: 740371651
PiperOrigin-RevId: 740379562
Introduces a new `download-jax-only-from-gcs` variable to the workflow configs. When set to 1, the test workflows will only download and install the `jax` wheel. Other artifacts such as the latest releases of `jaxlib` and the CUDA plugin dependencies will be downloaded and installed from PyPI.

PiperOrigin-RevId: 740430538
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner March 25, 2025 20:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.