Skip to content

Ci upstream sync 168 1 reb #357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 487 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
487 commits
Select commit Hold shift + click to select a range
e6b9250
Replace uses of deprecated `Shape::rank()` with:
Google-ML-Automation Mar 28, 2025
789f8e5
[Mosaic GPU] Ignore layouts that are already set when computing defau…
dimitar-asenov Mar 28, 2025
2b561f8
Marked as thread_unsafe_test:
vfdev-5 Mar 28, 2025
166b269
Add a small atol bump to `betainc` test in `LaxVmapOpTest`
ayaka14732 Mar 28, 2025
619969a
Fix error in build.py when trying to build aarch64 jaxlib wheel.
Google-ML-Automation Mar 28, 2025
98e947b
Use a 16 core Windows runner when building artifacts
nitins17 Mar 28, 2025
111a5d0
Remove get_emit_python_callback_descriptor from the type stubs.
hawkinsp Mar 28, 2025
eabcf41
Remove a use of XlaComputation from call_tf.
hawkinsp Mar 28, 2025
61b3a99
Change the `step counter` to an `init flag`
Google-ML-Automation Mar 28, 2025
7af7bcb
Make sure `vma` on ShapedArray exists by default to make development …
yashk2810 Mar 28, 2025
a466fa0
Fixed failing `ExcessPrecisionTest.test_matmul_f32_out_simple` test.
mwhittaker Mar 28, 2025
bdd3667
[array API] make capabilities more accurate
jakevdp Mar 26, 2025
f2eb52e
Clean up: num_groups = num_q_heads // num_kv_heads
Google-ML-Automation Mar 28, 2025
f0fb8fe
Move aval_to_xla_shape into callback.py, which is its only user.
hawkinsp Mar 28, 2025
aecda56
[pallas:mosaic_gpu] `GPUMesh` now accepts axis names in a more struct…
superbobry Mar 28, 2025
026f98d
[pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to…
superbobry Mar 28, 2025
2079ac0
Set NB_DOMAIN=jax
hawkinsp Mar 28, 2025
cc2e9e5
implement nbytes for PRNGKeyArray
ZacCranko Mar 28, 2025
3f5efeb
[NFC] Fix linter errors in pipeline file
Google-ML-Automation Mar 28, 2025
2c6569b
cleanup now that we depend on ml_dtypes>=0.5
jakevdp Mar 28, 2025
edc1c43
scan: improve docs & errors around dynamic length
jakevdp Mar 28, 2025
9a5134a
Use f32 scratch for output so we only need to transfer output with de…
Google-ML-Automation Mar 28, 2025
94a7a54
Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter…
yashk2810 Mar 28, 2025
47cad6f
PR #27576: [attrs] experimental appendattr
mattjj Mar 28, 2025
b4b792d
Remove GPU-specific dependencies from backend-independent tests.
Google-ML-Automation Mar 28, 2025
d75908a
add discord release action
ZacCranko Mar 20, 2025
8edb7af
Add vma rules for pmin and pmax
yashk2810 Mar 28, 2025
3b2b705
DOC: add documentation note about default dtypes
jakevdp Mar 28, 2025
e90cba8
Fix an edge-case in reshape sharding rule where the last splitting/me…
yashk2810 Mar 29, 2025
8302460
Update XLA dependency to use revision
Google-ML-Automation Mar 29, 2025
f0bdc49
Update XLA dependency to use revision
Google-ML-Automation Mar 30, 2025
f1acbb9
[mgpu] Register the mosaic_gpu dialect regardless of warpgroup/lane l…
cperivol Mar 30, 2025
a46a72f
[mgpu/pallas] Expose WGMMA_TRANSPOSED layout
cperivol Mar 30, 2025
c254e51
[Pallas:MGPU] Only allow small tiling in Pallas programs
apaszke Mar 31, 2025
9e0e9b2
jax.core: finalize a number of deprecations for JAX v0.6.0
jakevdp Mar 28, 2025
fbc1528
[pallas:mgpu] Allow more freedom for the user to transform references.
cperivol Mar 31, 2025
21ba6c3
[Pallas:MGPU] Remove (now) unnecessary TransposeTransforms
apaszke Mar 31, 2025
9762027
Update XLA dependency to use revision
Google-ML-Automation Mar 31, 2025
574e399
[pallas:mosaic_gpu] Run all Mosaic GPU-specific tests under WG semantics
superbobry Mar 31, 2025
faa4840
[jaxlib] Pack/unpack subbyte types to/from numpy arrays to support in…
danielsuo Mar 31, 2025
c76b8ad
Add Jax tracing micro benchmarks.
danielsuo Mar 31, 2025
2a02933
Remove legacy GPU kernel for LU decomposition.
dfm Mar 31, 2025
fa22da4
[pallas:mosaic_gpu] Fixed lane-level lowering of `lax.optimization_ba…
superbobry Mar 31, 2025
a4195fc
[array api] return all devices in devices()
jakevdp Mar 31, 2025
2b1ea58
Propagate sharding and vma rule for axis_index_p. There's no need for…
yashk2810 Mar 31, 2025
14b7bd8
Bump actions/setup-python from 5.4.0 to 5.5.0
dependabot[bot] Mar 31, 2025
1776656
Minor docstring updates for AOT wrappers in error checking
ayaka14732 Mar 31, 2025
87408a4
__jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_…
jakevdp Mar 31, 2025
9f4594a
jnp.power: support __jax_array__ on inputs
jakevdp Mar 27, 2025
da3b53c
[Pallas TPU] Remove forward compatibility code for float -> signed co…
apaszke Apr 1, 2025
210e8fb
[Easy] Make pallas mesh grid handling more resilient to tuple names.
Google-ML-Automation Apr 1, 2025
5d38aa1
Update XLA dependency to use revision
Google-ML-Automation Apr 1, 2025
68dd0e0
Create the test targets for the wheel size verification.
Google-ML-Automation Apr 1, 2025
5687003
Add scan_p and cond_p vma rule.
yashk2810 Apr 1, 2025
28eb246
Bump tsickert/discord-webhook from 5.3.0 to 7.0.0
dependabot[bot] Mar 31, 2025
de44f26
Update permisisons community_release_actions.yml
ZacCranko Apr 1, 2025
a276db3
AutoPGLE: force-disable graphs less
olupton Mar 5, 2025
2d8adf4
Remove the try/except for Shardy imports.
belitskiy Apr 1, 2025
c29e024
[jaxlib] Roll back subbyte types due to failing asan tests.
danielsuo Apr 1, 2025
d421797
jnp.select: support __jax_array__ for inputs
jakevdp Apr 1, 2025
33b0181
jnp.einsum: add support for __jax_array__
jakevdp Apr 1, 2025
1df9bec
Enable test_scan_offload in memories_test.
Google-ML-Automation Apr 1, 2025
b6c02a8
[Mosaic-GPU] [2/3] Add NVSHMEM support to Mosaic-GPU custom call
nvcastet Mar 20, 2025
3ddc24d
Clarify documentation of jnp.heaviside
LouisJustinTALLOT Apr 1, 2025
d29b8ab
[mutable-arrays] add vmap rule for mutable_array_p, very basic test
mattjj Apr 1, 2025
adfe6d5
add reduction support in copy_smem_to_gmem
Amir-19 Mar 19, 2025
bf994f5
Add OOB checks to jax.numpy array indexing
ayaka14732 Apr 2, 2025
89b4c82
cumulative reductions: support __jax_array__ on inputs
jakevdp Apr 1, 2025
5e048f8
[better_errors] Fix the handling of kwargs for debug_info.
gnecula Mar 26, 2025
4de708c
upgrade docs from `jax.core` to `jax.extend.core` where needed to fix…
froystig Apr 2, 2025
c395ad2
make random_gamma_grad not a primitive anymore
mattjj Apr 1, 2025
879d435
Remove nanobind pin now that nanobind fix landed.
danielsuo Apr 2, 2025
898fa1b
[jaxlib] Fix asan tests for subbyte types in CPU/GPU callbacks.
danielsuo Apr 2, 2025
3ff00e0
Prepare for disallowing `jnp.array(None)`
superbobry Apr 2, 2025
07beae0
Update XLA dependency to use revision
Google-ML-Automation Apr 2, 2025
ddf9aea
Relax the aval check in `select_hlo_lowering_opaque` to only check fo…
yashk2810 Apr 2, 2025
cc631d3
Remove unused Attrs from `lu_pivots_to_permutation` FFI kernel.
dfm Apr 2, 2025
64698b3
Remove legacy GPU kernels for QR decomposition.
dfm Apr 2, 2025
18a5c86
Remove the extra stack frame that was introduce in uniform due to dro…
yashk2810 Apr 2, 2025
c23de3d
Updates for 3.14
vfdev-5 Jan 16, 2025
f2ddab7
Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1
vfdev-5 Apr 2, 2025
acf008a
jnp.isinf & friends: support __jax_array__
jakevdp Apr 1, 2025
1c7b36a
let XLA metadata be unset in nested dynamic scopes
froystig Apr 2, 2025
f1f4be0
Add pbroadcast insertion for `psum_p` in the traceable. This effectiv…
yashk2810 Apr 2, 2025
f0b6edc
Bump actions/cache from 4.2.0 to 4.2.3
dependabot[bot] Apr 2, 2025
01f3de8
Add simple vmap support for lax.ragged_all_to_all.
ghpvnist Apr 2, 2025
2c25128
jnp.concat and friends: support __jax_array__
jakevdp Apr 2, 2025
5163fbb
Support `conv` `unfused_flops` in roofline.
zacmustin Apr 2, 2025
3a74c76
[array_api] update array_api_version to 2024.12
jakevdp Apr 2, 2025
5ae91b3
`jnp.array` no longer accepts None
superbobry Apr 2, 2025
8c55b0b
Add CI workflow for JAX distibuted initialize in K8s jobsets
yhtang Sep 23, 2024
69ddad8
address review comments
yhtang Mar 17, 2025
bf89a31
update ratchet action pin
yhtang Mar 17, 2025
b7edb55
[pallas] Removed `pl.device_id`. Use `lax.axis_index` instead.
superbobry Apr 2, 2025
6fbed1e
Fix custom_transpose when composed with custom_jvp and use_direct_lin…
dfm Apr 2, 2025
71f9bcb
add an `out_sharding` option to `jax.random.bits`
froystig Apr 2, 2025
30e0803
[Mosaic GPU] Define the `mosaic_gpu.custom_primitive` dialect op.
bchetioui Apr 3, 2025
e9d032d
Prune passthrough outputs in lax.switch.
danielsuo Mar 12, 2025
8904882
use common `maybe_auto_axes` helper in `random.uniform`
froystig Apr 3, 2025
bd4df4a
add an `out_sharding` option to `jax.random.randint`
froystig Apr 3, 2025
eaeffbd
[CI] Enable nightly TPU CI tests for v6e.
MichaelHudgins Apr 3, 2025
7e5d47d
[pallas:mosaic_gpu] Slightly reworded the docstrings for a few recent…
superbobry Apr 3, 2025
3535508
[pallas:mosaic_gpu] `emit_pipeline*` now passes the loop indices into…
superbobry Apr 3, 2025
001b4d4
[Mosaic GPU] Get rid of `LayoutAttr` and related comments.
bchetioui Apr 3, 2025
9794ba9
Update XLA dependency to use revision
Google-ML-Automation Apr 3, 2025
5f78ff3
docs: compilation_cache_expect_pgle option
olupton Apr 1, 2025
14c0224
Restrict the regex for copying the wheels.
Google-ML-Automation Apr 3, 2025
f9403f7
Moved the `jax.Array` baseclass to C++
superbobry Apr 3, 2025
24ab517
Bump up tolerance in ShardMapSystematicTest.test_vmap_closure for GPUs.
belitskiy Apr 3, 2025
99b8959
Not to use dynamic grid in the ragged paged attention Pallas kernel.
Google-ML-Automation Apr 3, 2025
775ce4d
Straight-through estimator for nvfp4
wenscarl Mar 20, 2025
dca4e8c
Improve based on review 1
wenscarl Mar 20, 2025
7f0c6bc
Improve based on review 2
wenscarl Mar 27, 2025
51b91f9
Add optimization barrier.
wenscarl Mar 31, 2025
2a3cce4
Fix problem finding clang++ when building JAX via build.py on windows.
hawkinsp Apr 3, 2025
e614019
Only trigger K8s CI on changes to cluster config and distributed init…
yhtang Apr 3, 2025
a81b07e
[Mosaic GPU] Fix index_invariant slot in warp-specialized pipeline.
justinjfu Apr 3, 2025
87b222d
jax.numpy: support __jax_array__ in several more functions
jakevdp Apr 2, 2025
569fc1d
Eliminate DeprecationWarning in python3.12+ in jax pallas for ~.
Google-ML-Automation Apr 3, 2025
4a1a778
[export] Add support for override_lowering_rules to jax.export.
gnecula Apr 3, 2025
4caf626
[Mosaic GPU] Re-enable WS pipelined copy test.
justinjfu Apr 3, 2025
a77d8ba
[pallas:mgpu] General ref transform handling at lowering time.
cperivol Apr 3, 2025
a52cf11
[mgpu:pallas] Changes to allow the use of WGMMA_TRANSPOSED_LAYOUT.
cperivol Apr 3, 2025
4f0bd10
add an `out_sharding` option to `jax.random.permutation`
froystig Apr 3, 2025
56b9e69
[pallas:mgpu] Initial version of inline_mgpu op
cperivol Apr 4, 2025
3a241e7
Delete `PjRtClient.Defragment`.
zacmustin Apr 4, 2025
07cb682
require `out_shardings` as a keyword-only argument on public functions
froystig Apr 4, 2025
392b56e
Add `auto_axes`, `explicit_axes` and `manual_axes` properties to Mesh…
yashk2810 Apr 4, 2025
4ea85fc
[Mosaic TPU] Allow specify priority in enqueueDMA.
bythew3i Apr 4, 2025
6e039cd
[Pallas TPU] Support DMA priority in async copy start
bythew3i Apr 4, 2025
54f6182
[pallas:mosaic_gpu] Do not specify the default `index_map` in tests
superbobry Apr 4, 2025
39ce4e8
[Mosaic GPU] Return the combined softmax residuals.
Rifur13 Apr 4, 2025
814b809
Avoid double buffering when no windowing info is present.
Google-ML-Automation Apr 4, 2025
f352afd
[pallas:mosaic_gpu] Added pretty printing to primitives consuming refs
superbobry Apr 4, 2025
4a9c62b
[Mosaic GPU] Don't force TiledLayout.lane_dims to partition data
apaszke Apr 4, 2025
9d4ecae
Fixed deadlock in NamedSharding ctor
vfdev-5 Apr 3, 2025
4e30f3c
[Mosaic GPU] Allow replicating data over warps
apaszke Apr 4, 2025
7779ed0
[pallas:mgpu] Check that swizzle dim is not transposed in copy_smem_t…
cperivol Apr 4, 2025
b88778f
[mgpu:pallas] Typo in `UnswizzleRef.untransform_reshape()` check.
cperivol Apr 4, 2025
e807f18
Update XLA dependency to use revision
Google-ML-Automation Apr 4, 2025
8b72599
[mgpu:pallas] Swizzle elements computed using bitwidth rather than by…
cperivol Apr 4, 2025
f87b82c
[mgpu] Foreach to handle scalar registers in fragmented arrays.
cperivol Apr 4, 2025
d5bf612
[mgpu:pallas] Fix swizzling check bug where it was comparing w/ #byte…
cperivol Apr 4, 2025
c9faf78
`_attempt_rewriting_take_via_slice()`: canonicalize the slice index b…
Google-ML-Automation Apr 4, 2025
c7593ea
add an `out_sharding` option to `jax.random.truncated_normal`
froystig Apr 4, 2025
f4b6b27
Always force synchronous pipelining when we have vmem storage and tri…
Google-ML-Automation Apr 4, 2025
9e2e7d9
Parameterize the random tests taking out_sharding argument in pjit_te…
yashk2810 Apr 4, 2025
47d8683
fix(docs): corrected the name of the function call in the document
Qazalbash Apr 4, 2025
cd81a57
Fix a possible race in pjit.cc.
hawkinsp Apr 4, 2025
c45547c
PR #26906: [jax.distributed] Allow explicitly setting slice_index
gspschmid Apr 4, 2025
0a17035
Don't set `memory_kind` to `None` if the mesh is AbstractMesh and the
yashk2810 Apr 4, 2025
fb8343b
[Mosaic GPU] Limit the maximum number of registers per thread to 255.
Rifur13 Apr 4, 2025
40c735d
[pallas:mosaic_gpu] Fixed a typo in `_barrier_arrive_pp_eqn`
superbobry Apr 4, 2025
71730c9
Check that memory_kind of an aval is always None
yashk2810 Apr 5, 2025
4a999b0
Update XLA dependency to use revision
Google-ML-Automation Apr 5, 2025
8368b40
Update XLA dependency to use revision
Google-ML-Automation Apr 6, 2025
9e21bd5
Automated Code Change
Google-ML-Automation Apr 6, 2025
d99a771
Automated Code Change
Google-ML-Automation Apr 6, 2025
2f94197
Automated Code Change
Google-ML-Automation Apr 6, 2025
bacfe65
Automated Code Change
Google-ML-Automation Apr 6, 2025
912e1e3
Automated Code Change
Google-ML-Automation Apr 6, 2025
b3ec162
Fix deadlock when computing cached Sharding::type() values.
hawkinsp Apr 6, 2025
8008fb7
Raise an error if the type passed to `axis_types` argument of `Mesh` …
yashk2810 Apr 7, 2025
c6744b2
[Mosaic GPU] Support Slice and Transpose in the Pallas WGMMA lowering
dimitar-asenov Apr 7, 2025
7e4a367
Use `contextlib.nullcontext` instead of `trivial_ctx`
superbobry Apr 7, 2025
34a60b3
Fix some test timeouts
apaszke Apr 7, 2025
dfb267e
Removed unused deprecations
superbobry Apr 7, 2025
e478055
[Mosaic GPU] Delete mentions of `WGMMARowFragLayout` in `layouts.py`.
bchetioui Apr 7, 2025
353ecef
Add a missing jaxlib version check in Pallas TPU lowering
apaszke Apr 7, 2025
367c2c8
Add more TSAN skips to avoid timeouts
apaszke Apr 7, 2025
e7ade2b
`jex.core.Var` is no longer ordered
superbobry Apr 7, 2025
46c91ff
Fix a race in pjit under free threading.
hawkinsp Apr 7, 2025
b08db9c
Update XLA dependency to use revision
Google-ML-Automation Apr 7, 2025
45801a8
[Mosaic GPU] Add missing to/from tiled layout attributes with replica…
dimitar-asenov Apr 7, 2025
05dc947
Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
hawkinsp Apr 7, 2025
27aea57
Temporarily skip JaxNumpyErrorTests in multi-thread environments
ayaka14732 Apr 7, 2025
f1c26a4
Reverts 735cec18cb2f8dff2aea5e503fd886a37aee094e
danielsuo Apr 7, 2025
2069eb3
Deprecate public export of mlir.custom_call.
dfm Apr 7, 2025
40a10ff
Removed unused `jax_remat_opt_barrier` config option
superbobry Apr 7, 2025
8e9e9d7
Removed deprecated `jax.core.{full_lower,jaxpr_as_fun,lattice_join}`
superbobry Apr 7, 2025
0f0f438
Apply forwarding in pjit linearization rule to avoid intermediate cop…
dfm Apr 4, 2025
9f1fe0a
Add int4, uint4 to test_util.suppported_types
jburnim Apr 7, 2025
7a631e0
[Pallas Fuser] Add output_fusion_mask support
sharadmv Apr 7, 2025
56002e7
[Mosaic TPU] FWD compatibility needs to keep previous version at leas…
bythew3i Apr 7, 2025
b2c1945
Add scalar event logging function
jeffcarp Feb 27, 2025
56bebde
Update XLA dependency to use revision
Google-ML-Automation Apr 7, 2025
41a33c6
[CI] Temporarily disable TPU v6 due to runner issues
MichaelHudgins Apr 7, 2025
b209309
[shard-map] in eager shmap, handle all rep rule output cases
mattjj Apr 7, 2025
61d142d
[mgpu] Allow bf16 printing
cperivol Apr 7, 2025
c1f3b09
Add jaxlib_extension_version guard against explicit copying
pschuh Apr 7, 2025
83fe4b6
Bump medyagh/setup-minikube from 0.0.18 to 0.0.19
dependabot[bot] Apr 7, 2025
2e212ce
Migrate jaxlib to use a single common .so file for all C++ dependencies.
hawkinsp Apr 7, 2025
2f19b24
[Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSem…
justinjfu Apr 7, 2025
ccb92a1
Relax jax dependency constraints to be able to install RC wheels
nitins17 Apr 7, 2025
1d661f2
Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
Google-ML-Automation Apr 7, 2025
73ada01
jax.numpy: support __jax_array__ in remaining APIs
jakevdp Apr 7, 2025
61aa365
Removed `data_dependent_tracing_fallback` config option
superbobry Apr 7, 2025
988b9bb
Add **experimental** `with_dll_constraint` API. This is for cases whe…
yashk2810 Apr 7, 2025
da72c18
Add custom pretty print rule for the unary ops with accuracy s.t. acc…
hanrach9 Apr 7, 2025
927dc61
Apply output forwarding in lin rule for pjit.
dfm Apr 7, 2025
4c6b036
harden cache against jaxlib ver
ZacCranko Apr 7, 2025
322632f
Migrate custom_call filecheck to use internal custom_call since the e…
dfm Apr 8, 2025
d3f27db
Remove unused function jax._src.interpreters.mlir.xla_computation_to_…
hawkinsp Apr 8, 2025
03e7cb8
Address previous FP8-related TODOs in jaxlib/XLA.
apivovarov Apr 8, 2025
1b918dc
[export] Add backwards compatibility test for annotate_device_placement.
gnecula Apr 8, 2025
5ddff5f
[Mosaic GPU] Add support for replicated warp_dim parsing and a dedica…
dimitar-asenov Apr 8, 2025
3aacfce
[export] Add support for serializing functions with PRNG keys as inpu…
gnecula Apr 7, 2025
82d5d6e
Removed `jax._src.raise_to_shaped`
superbobry Apr 8, 2025
2ade7da
Removed redundant `pass`es
superbobry Apr 8, 2025
cfd9e51
[Mosaic GPU] Refactor and generalize code in `optimization_barrier`.
dimitar-asenov Apr 8, 2025
e8dcd37
[Mosaic GPU] Add warpgroup lowering for `RunState` in Pallas.
dimitar-asenov Apr 8, 2025
1264f07
Removed `eager_pmap` config option
superbobry Apr 8, 2025
199fa7d
[pallas:mosaic_gpu] `emit_pipeline*` now allows the grid to be dynamic
superbobry Apr 8, 2025
f12a4af
Remove unused `return wrapper` in annotate_function that creates a se…
Google-ML-Automation Apr 8, 2025
dad3e7d
Add a skeleton for Pallas:Mosaic GPU documentation
apaszke Apr 8, 2025
aaef00f
Update XLA dependency to use revision
Google-ML-Automation Apr 8, 2025
72e420c
[Mosaic GPU] Simplify load/store methods now that we have fewer layouts
apaszke Apr 8, 2025
cb8d42c
[pallas:mosaic_gpu] Added test for custom pretty-printing rules
superbobry Apr 8, 2025
2c17538
[Mosaic TPU] Add support for non-32bit types in vector.extract
apaszke Apr 8, 2025
c4340d9
Replace references to jax.readthedocs.io with docs.jax.dev.
hawkinsp Apr 8, 2025
4b4f828
[mutable-arrays] limit implicit ref_swap dtype promotion
mattjj Apr 8, 2025
54e8df4
[shard-map] add while_map rep rule
mattjj Apr 7, 2025
4e09452
[Mosaic TPU] Add MemRead and MemStore effects to load and store ops.
bythew3i Apr 8, 2025
67df04c
[shard-map] canonicalize rep=None to be rep={all possible axes}
mattjj Apr 7, 2025
689f766
change tack...
mattjj Apr 7, 2025
84210a3
jnp.repeat: don't cast repeats to array, as they must be static.
jakevdp Apr 8, 2025
6965601
jnp.linalg: add symmetrize_input argument & docs
jakevdp Apr 7, 2025
872a43d
Clarify jax.make_jaxpr docstring
j-towns Apr 8, 2025
7fd3a07
Simplify handling of type stubs in nanobind extension rules.
hawkinsp Apr 8, 2025
8d71106
Added `jax.no_tracing` to the API docs
superbobry Apr 8, 2025
879b72a
Remove reexports of ml_dtypes types from xla_client.py.
hawkinsp Apr 8, 2025
686144b
Finalize deprecation of `ffi_call` with inline arguments.
dfm Apr 8, 2025
d094562
Split weakref_lru_cache into its own extension.
hawkinsp Apr 8, 2025
8a6bfd6
Make changes to shard_map to prepare for setting `varying_axes_in_typ…
yashk2810 Apr 8, 2025
185d65f
[ragged-paged-attn] Unify kv strided load to one.
bythew3i Apr 8, 2025
bfe7923
Enable public doc for scaled dot
kaixih Mar 26, 2025
56d13e0
Remove asserts
kaixih Apr 3, 2025
8993c0e
format
kaixih Apr 3, 2025
8cd2843
[JAX] Remove calls to xla_computation_to_mlir_module.
hawkinsp Apr 8, 2025
ba0879a
Disable second order vjp tests in RunStateHypothesisTest.test_vjp if …
vfdev-5 Mar 11, 2025
ef84e9d
Rename `pbroadcast` to `pvary` and expose it as `jax.lax.pvary`.
yashk2810 Apr 8, 2025
48fcf02
Rename `psum2` to `psum_invariant` and put it in `lax_parallel`. We s…
yashk2810 Apr 9, 2025
9af8e04
Fix typo in the error message
yashk2810 Apr 9, 2025
16d737b
Account for versioned clang binaries
charleshofer Apr 10, 2025
3b4a7b0
Make Clang use manylinux C++ standard library
charleshofer Apr 11, 2025
248e638
Move clang gcc path options to config file
charleshofer Apr 14, 2025
98baf09
Add clang config file
charleshofer Apr 14, 2025
ea55e59
Use .cfg file
charleshofer Apr 14, 2025
9de75a2
Trivial change for CI
charleshofer Apr 14, 2025
6985f0d
Remove 6.2.4 build
charleshofer Apr 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
18 changes: 10 additions & 8 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ build -c opt
build --output_filter=DONT_MATCH_ANYTHING

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
build --copt=-DNB_DOMAIN=jax

# #############################################################################
# Platform Specific configs below. These are automatically picked up by Bazel
Expand Down Expand Up @@ -97,6 +98,7 @@ build:windows --incompatible_strict_action_env=true
# #############################################################################
build:nonccl --define=no_nccl_support=true

build --repo_env USE_PYWRAP_RULES=1
build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
Expand Down Expand Up @@ -138,13 +140,13 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This config is used for building targets with CUDA libraries from stubs.
Expand Down Expand Up @@ -262,8 +264,8 @@ build:ci_darwin_arm64 --color=yes
# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

Expand Down Expand Up @@ -331,9 +333,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst
build:rbe_windows_amd64 --config=rbe

# Set the host, execution, and target platform
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"

build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
build:rbe_windows_amd64 --enable_runfiles
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/bazel_cuda_non_rbe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest"

env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
Expand Down Expand Up @@ -79,6 +79,7 @@ jobs:
continue-on-error: true
run: >-
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
Expand Down
65 changes: 65 additions & 0 deletions .github/workflows/bazel_optional_cuda.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: CI - Bazel Optional CUDA tests
on:
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'
pull_request:
branches:
- main
types: [ labeled, synchronize ]
schedule:
- cron: "0 */2 * * *" # Run once every 2 hours
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
run_tests:
if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest'
strategy:
matrix:
# Optional gpus to run against
runner: ["linux-x86-a4-224-b200-1gpu"]
name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})"
# End Presubmit Naming Check github-cuda-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CUDA Tests
run: |
nvidia-smi
bazel test --config=ci_linux_x86_64_cuda \
--config=resultstore \
--config=rbe_cache \
--repo_env=HERMETIC_CUDA_VERSION="12.8.0" \
--repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \
--repo_env=HERMETIC_PYTHON_VERSION="3.13" \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
--run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
--test_output=errors \
--test_env=JAX_ACCELERATOR_COUNT=1 \
--test_env=JAX_TESTS_PER_ACCELERATOR=32 \
--local_test_jobs=32 \
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
--test_tag_filters=-multiaccelerator \
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
--test_env=JAX_SKIP_SLOW_TESTS=true \
--action_env=JAX_ENABLE_X64="1" \
--action_env=NCCL_DEBUG=WARN \
--color=yes \
//tests:gpu_tests //tests:backend_independent_tests \
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
12 changes: 6 additions & 6 deletions .github/workflows/build_artifacts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ on:
default: "linux-x86-n2-16"
options:
- "linux-x86-n2-16"
- "linux-arm64-c4a-64"
- "windows-x86-n2-64"
- "linux-arm64-t2a-48"
- "windows-x86-n2-16"
artifact:
description: "Which JAX artifact to build?"
type: choice
Expand Down Expand Up @@ -119,11 +119,11 @@ jobs:

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Enable RBE if building on Linux x86
if: contains(inputs.runner, 'linux-x86')
- name: Enable RBE if building on Linux x86 or Windows x86
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86
if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86')
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
if: contains(inputs.runner, 'linux-arm64')
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
# Halt for testing
- name: Wait For Connection
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python 3.11
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.11
- run: python -m pip install pre-commit
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
- uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
Expand All @@ -64,7 +64,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -102,7 +102,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -130,7 +130,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -156,7 +156,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -187,7 +187,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.12
- name: Install JAX
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ jobs:
matrix:
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
Expand Down
34 changes: 34 additions & 0 deletions .github/workflows/community_release_actions.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Release Actions

on:
release:
types: [published]

permissions:
contents: read

jobs:
discord_release:
if: github.repository_owner == 'jax-ml'
runs-on: ubuntu-latest
steps:
- name: Get release URL
id: get-release-url
run: |
URL="https://docs.jax.dev/en/latest/changelog.html"
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
with:
stringToTruncate: |
JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!

${{ github.event.release.body }}
maxLength: 2000
truncationSymbol: "..."
- name: Discord Webhook Action
uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
4 changes: 2 additions & 2 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
109 changes: 109 additions & 0 deletions .github/workflows/k8s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
name: Distributed run using K8s Jobset

on:
push:
branches:
- main
paths:
- 'jax/distributed.py'
- 'jax/_src/distributed.py'
- 'jax/_src/clusters/**'
pull_request:
branches:
- main
paths:
- 'jax/distributed.py'
- 'jax/_src/distributed.py'
- 'jax/_src/clusters/**'

permissions:
contents: read

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

defaults:
run:
shell: bash -ex -o pipefail {0}

jobs:

distributed-initialize:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4
with:
path: jax

- name: Start Minikube cluster
uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/[email protected]

- name: Install K8s Jobset
run: |
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml

- name: Build image
run: |
cat > Dockerfile <<EOF
FROM ubuntu:22.04
ADD jax /opt/jax
RUN apt-get update && apt-get install -y python-is-python3 python3-pip
RUN pip install -e /opt/jax[k8s]
EOF

minikube image build -t local/jax:latest .

- name: Create service account for K8s job introspection
run: |
kubectl apply -f jax/examples/k8s/svc-acct.yaml

- name: Prepare test job
run: |
export VERSION=v4.44.3
export BINARY=yq_linux_amd64
wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq

cat jax/examples/k8s/example.yaml |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
tee example.yaml

- name: Submit test job
run: |
kubectl apply -f example.yaml

- name: Check job status
shell: bash -e -o pipefail {0}
run: |
while true; do
status=$(kubectl get jobset example -o yaml | yq .status.conditions[0].type)
timestamp=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$timestamp] Checking job status..."

if [ "$status" == "Completed" ]; then
echo "[$timestamp] Job has completed successfully!"
exit 0
elif [ "$status" == "Failed" ]; then
echo "[$timestamp] Job has failed!"
exit 1
else
echo "[$timestamp] Job is still running. Current pod status:"
kubectl get pods --no-headers
echo "[$timestamp] Waiting for 3 seconds before checking again..."
sleep 3
fi
done

- name: Examine individual pod outputs
if: "!cancelled()"
run: |
set +x
kubectl get pods --no-headers | awk '{print $1}' | while read -s pod; do
echo "========================================"
echo "Pod $pod output:"
echo "----------------------------------------"
kubectl logs $pod
echo "========================================"
done
Loading
Loading