forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
CI: 03/25/25 upstream sync #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PiperOrigin-RevId: 738315605
PiperOrigin-RevId: 738321801
PiperOrigin-RevId: 738342517
http://github.com/openxla/xla/commit/0d20d73f2c8f21c21b9f343c4363a76e980f032e. PiperOrigin-RevId: 738352930
PiperOrigin-RevId: 738376533
…ady `jit`ted. PiperOrigin-RevId: 738393973
PiperOrigin-RevId: 738398099
This code is Mosaic specific, move it to the Mosaic directory. PiperOrigin-RevId: 738404429
PiperOrigin-RevId: 738410122
…MAOp`. Now that we have full control over strides in the lowering, these attributes are no longer necessary. PiperOrigin-RevId: 738418852
…oc around PiperOrigin-RevId: 738421256
This callback functionality is only used by JAX and shipped as part of its CUDA and ROCM GPU plugins. Move it into JAX, as part of a wider move of xla/python pieces that belong to JAX into JAX. PiperOrigin-RevId: 738426489
This is a GPU-specific target. PiperOrigin-RevId: 738441625
PiperOrigin-RevId: 738443014
…xis_index` PiperOrigin-RevId: 738448436
PiperOrigin-RevId: 738454875
PiperOrigin-RevId: 738457532
PiperOrigin-RevId: 738492394
PiperOrigin-RevId: 738496186
PiperOrigin-RevId: 738498184
…e more cleanups PiperOrigin-RevId: 738503430
The new xla_extension_version is 320. PiperOrigin-RevId: 738522486
Indexing is less verbose and is thus easier to read in most cases. The functional API is really only necessary for masked loads and stores. PiperOrigin-RevId: 740058341
…ension and load a ref into WGMMAColFragLayout format. PiperOrigin-RevId: 740068368
PiperOrigin-RevId: 740084295
If `auditwheel show` is executed on `jax` wheel, the following message is printed: ``` INFO:auditwheel.main_show:This does not look like a platform wheel, no ELF executable or shared library file (including compiled Python C extension) found in the wheel archive ``` PiperOrigin-RevId: 740096302
…jit APIs. Fixes jax-ml#27390 PiperOrigin-RevId: 740104692
… `:build_jaxlib=false`. PiperOrigin-RevId: 740115575
PiperOrigin-RevId: 740142231
Before this change, we handled attrs for initial-style primitives like jit/scan like this: 1. the traceable would form a jaxpr and see what attrs were touched (by jax_getattr or jax_setattr), 2. for each such attr, the traceable would do jax_getattr to get the current value, tree-flatten, pass the flat valuesinto the (pure) bind, get the new values out, tree-unflatten, then jax_setattr the result. That approach would error if the function called `jax_setattr` to set a previously non-existant attr. That is, this would work: ```python from jax.experimental.attrs import jax_setattr class Thing: ... thing = Thing() jax_setattr(thing, 'x', 1.0) ``` but it wouldn't work under a `jax.jit`. This commit makes the same code work under a jit. We just 1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation are deleted, using a special sentinel value `dne_sentinel` to indicate the attribute initially did not exist before tracing; 2. in pjit.py's `_get_states`, when reading initial attr values before the pjit_p bind, if the attribute does not exist we don't try to read it and instead just use `dne_sentinel` as the value, which is a convenient empty pytree; 3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based on the current attr states, we map attrs that don't exist to `dne_sentinel` (rather than just erroring when the attr doesn't exist, as before). In short, we use a special value to indicate "does not exist". If `jax_getattr` supported the 'default' argument, the code would be a little cleaner since we could avoid the `if hasattr` stuff. And that's probably a useful feature to have anyway. We can add that in a follow-up. This PR only makes setattr-to-nonexistant-attr work with jit. We'll add scan etc in follow-ups.
PiperOrigin-RevId: 740280136
…ult layout inference. default_vector_size is initialized with `math.inf` and is never `None`. PiperOrigin-RevId: 740283678
http://github.com/openxla/xla/commit/d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5. PiperOrigin-RevId: 740293121
We need to include the type caster for std::string_view if we use nb::cast<std::string_view>. PiperOrigin-RevId: 740311318
PiperOrigin-RevId: 740312383
PiperOrigin-RevId: 740318556
OSS Jax builds for GPU backends split `jaxlib` into three wheels and since we cannot expect a stable C++ ABI among the shared libraries, we refactor to ensure: 1. C++ objects are not created/consumed by different shared libraries. 2. Static objects are declared and defined appropriately. This PR: 1. Migrates Jax XLA FFI callback handlers from XLA's Internal FFI API to the [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis). Note that we update both CPU and GPU handlers because we cannot mix Internal and External APIs. 2. Updates how FFI GPU handlers are registered, now analogous to how the original GPU custom call was registered. 3. Adds an `xla::ffi::ExecutionContext` member to `ifrt::PjRtLoadedExectuable` holding opaque pointers to callbacks. 4. Updates Jax `callback.py` to call the new FFI callback handlers. PiperOrigin-RevId: 740327296
Also in passing fix up some header guards and authorship comments. PiperOrigin-RevId: 740337166
…mmits/cdb53266e6c251d91a2c321d64e8466caff129a9) PiperOrigin-RevId: 740345806
…ne and WG thread semantics PiperOrigin-RevId: 740371195
PiperOrigin-RevId: 740371651
PiperOrigin-RevId: 740379562
…_mode6 PiperOrigin-RevId: 740381115
PiperOrigin-RevId: 740416272
…allas_1 PiperOrigin-RevId: 740417766
PiperOrigin-RevId: 740422195
PiperOrigin-RevId: 740428623
Introduces a new `download-jax-only-from-gcs` variable to the workflow configs. When set to 1, the test workflows will only download and install the `jax` wheel. Other artifacts such as the latest releases of `jaxlib` and the CUDA plugin dependencies will be downloaded and installed from PyPI. PiperOrigin-RevId: 740430538
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream