Skip to content

CI: 04/16/25 upstream sync #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 785 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
785 commits
Select commit Hold shift + click to select a range
b18dc1d
[Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSem…
justinjfu Apr 7, 2025
64e4bf2
Relax jax dependency constraints to be able to install RC wheels
nitins17 Apr 7, 2025
3a3c145
[shard-map] canonicalize rep=None to be rep={all possible axes}
mattjj Apr 7, 2025
48a9ad0
Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
Google-ML-Automation Apr 7, 2025
3420546
Merge pull request #27716 from jakevdp:jax-array
Google-ML-Automation Apr 7, 2025
2944e3b
Removed `data_dependent_tracing_fallback` config option
superbobry Apr 7, 2025
0a72e85
Add **experimental** `with_dll_constraint` API. This is for cases whe…
yashk2810 Apr 7, 2025
84e04fe
Add custom pretty print rule for the unary ops with accuracy s.t. acc…
hanrach9 Apr 7, 2025
ca6e470
harden cache against jaxlib ver
ZacCranko Apr 7, 2025
9e03686
Merge pull request #27793 from dfm:lin-out-fwd
Google-ML-Automation Apr 8, 2025
4bae9cd
Merge pull request #27814 from ZacCranko:harden-cache
Google-ML-Automation Apr 8, 2025
3158996
Migrate custom_call filecheck to use internal custom_call since the e…
dfm Apr 8, 2025
86de478
Remove unused function jax._src.interpreters.mlir.xla_computation_to_…
hawkinsp Apr 8, 2025
bb515aa
Address previous FP8-related TODOs in jaxlib/XLA.
apivovarov Apr 8, 2025
51dbcd4
[export] Add backwards compatibility test for annotate_device_placement.
gnecula Apr 8, 2025
19fcae1
[Mosaic GPU] Add support for replicated warp_dim parsing and a dedica…
dimitar-asenov Apr 8, 2025
bc11a63
Clarify jax.make_jaxpr docstring
j-towns Apr 8, 2025
c2eaedf
Merge pull request #27776 from gnecula:export_keys
Google-ML-Automation Apr 8, 2025
8ed59d8
Removed `jax._src.raise_to_shaped`
superbobry Apr 8, 2025
af072fe
Removed redundant `pass`es
superbobry Apr 8, 2025
d12cbff
[Mosaic GPU] Refactor and generalize code in `optimization_barrier`.
dimitar-asenov Apr 8, 2025
c4cc94a
[Mosaic GPU] Add warpgroup lowering for `RunState` in Pallas.
dimitar-asenov Apr 8, 2025
12811f0
Removed `eager_pmap` config option
superbobry Apr 8, 2025
5f33280
[pallas:mosaic_gpu] `emit_pipeline*` now allows the grid to be dynamic
superbobry Apr 8, 2025
73ecf0b
Remove unused `return wrapper` in annotate_function that creates a se…
Google-ML-Automation Apr 8, 2025
511f782
Add a skeleton for Pallas:Mosaic GPU documentation
apaszke Apr 8, 2025
aa6e701
Merge pull request #27827 from apaszke:mgpu-docs
Google-ML-Automation Apr 8, 2025
d6524dc
Update XLA dependency to use revision
Google-ML-Automation Apr 8, 2025
b926fac
[Mosaic GPU] Simplify load/store methods now that we have fewer layouts
apaszke Apr 8, 2025
f5d73b8
[pallas:mosaic_gpu] Added test for custom pretty-printing rules
superbobry Apr 8, 2025
b8353d1
[Mosaic TPU] Add support for non-32bit types in vector.extract
apaszke Apr 8, 2025
e02faab
Replace references to jax.readthedocs.io with docs.jax.dev.
hawkinsp Apr 8, 2025
ae95797
change tack...
mattjj Apr 7, 2025
4d2808c
[mutable-arrays] limit implicit ref_swap dtype promotion
mattjj Apr 8, 2025
b7d430f
jnp.repeat: don't cast repeats to array, as they must be static.
jakevdp Apr 8, 2025
03c1bf9
Merge pull request #27803 from mattjj:27644
Google-ML-Automation Apr 8, 2025
29cb6cd
[Mosaic TPU] Add MemRead and MemStore effects to load and store ops.
bythew3i Apr 8, 2025
ef68063
Merge pull request #27809 from mattjj:26621
Google-ML-Automation Apr 8, 2025
b073e8d
Merge pull request #27836 from jakevdp:fix-repeat
Google-ML-Automation Apr 8, 2025
76825a2
Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
Google-ML-Automation Apr 8, 2025
f1bcf3b
Merge pull request #27821 from j-towns:clarify-make-jaxpr-docstr
Google-ML-Automation Apr 8, 2025
a43136b
Simplify handling of type stubs in nanobind extension rules.
hawkinsp Apr 8, 2025
62df2e8
Added `jax.no_tracing` to the API docs
superbobry Apr 8, 2025
09fed2f
Remove reexports of ml_dtypes types from xla_client.py.
hawkinsp Apr 8, 2025
2d44f98
Finalize deprecation of `ffi_call` with inline arguments.
dfm Apr 8, 2025
b4629c2
Split weakref_lru_cache into its own extension.
hawkinsp Apr 8, 2025
8301c30
Make changes to shard_map to prepare for setting `varying_axes_in_typ…
yashk2810 Apr 8, 2025
5a340a9
Disable second order vjp tests in RunStateHypothesisTest.test_vjp if …
vfdev-5 Mar 11, 2025
7b45552
[ragged-paged-attn] Unify kv strided load to one.
bythew3i Apr 8, 2025
d19a458
fix docstring
Cjkkkk Apr 8, 2025
b8d9e7f
Merge pull request #27503 from kaixih:enable_doc_scaled_dot
Google-ML-Automation Apr 8, 2025
a516988
[JAX] Remove calls to xla_computation_to_mlir_module.
hawkinsp Apr 8, 2025
373ac2e
Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2
Google-ML-Automation Apr 8, 2025
84016bc
Rename `pbroadcast` to `pvary` and expose it as `jax.lax.pvary`.
yashk2810 Apr 8, 2025
f95f6a8
Rename `psum2` to `psum_invariant` and put it in `lax_parallel`. We s…
yashk2810 Apr 9, 2025
4275135
Fix typo in the error message
yashk2810 Apr 9, 2025
866e32b
[pallas:mosaic_gpu] `ModuleContext.reserve_barrier` is now a context …
superbobry Apr 9, 2025
88555c2
Fix data race in weakref_lru_cache under free-threading.
hawkinsp Apr 9, 2025
6792703
Fix failing documentation tests
apaszke Apr 9, 2025
9306a7a
Avoid the problematic 2.19.2 release of tensorboard-plugin-profile
apaszke Apr 9, 2025
bbb8066
Automated Code Change
chsigg Apr 9, 2025
90cf9b4
Merge pull request #27866 from hawkinsp:race
Google-ML-Automation Apr 9, 2025
d9d2f0b
Update XLA dependency to use revision
Google-ML-Automation Apr 9, 2025
3d006e9
Fix build for clang >=19
medaminezghal Apr 9, 2025
8521c01
Merge pull request #27870 from medaminezghal:fix-clang19
Google-ML-Automation Apr 9, 2025
9adc3cc
[Mosaic GPU] Add a `LayoutCast` op to the Mosaic GPU mlir dialect.
dimitar-asenov Apr 9, 2025
5f4c6e4
Update XLA dependency to use revision
Google-ML-Automation Apr 9, 2025
e0cda84
Fix linear_call to allow recursive definitions.
dfm Apr 8, 2025
76c6b5b
More changes for enabling `vma` by default in JAX
yashk2810 Apr 9, 2025
8383af0
[shard-map] fix another bug where we incorrectly handled None in chec…
mattjj Apr 9, 2025
b6a4631
Merge tuple_replace and tuple_update in jax._src.util.
carlosgmartin Apr 9, 2025
21a4429
Merge pull request #27879 from mattjj:shmap-fix-5
Google-ML-Automation Apr 9, 2025
e750d7e
Add option for debug print to be called on partitioned arguments rath…
danielsuo Apr 9, 2025
0385667
Merge pull request #27853 from carlosgmartin:merge_tuple_update_tuple…
Google-ML-Automation Apr 9, 2025
2b3839d
[shard-map] make shard_map work with custom_jvp symbolic zeros
mattjj Apr 8, 2025
2863b48
Merge pull request #27759 from mattjj:vmappable-bind-fix
Google-ML-Automation Apr 9, 2025
c418495
Merge pull request #27886 from mattjj:26763
Google-ML-Automation Apr 9, 2025
e772a08
Fix docstrings of segment_{prod,max,min} after commit 4679f45.
arnoegw Apr 9, 2025
e9308a2
Added update_one_slot race tsan suppression for 3.14
vfdev-5 Apr 9, 2025
7c1595a
Skip jax/tests:unary_ops_accuracy_test when running with older versio…
GleasonK Apr 9, 2025
56646af
Remove xla_extension.stablehlo_to_mhlo.
hawkinsp Apr 9, 2025
97a0d75
[shard-map] add docs for VMAs
mattjj Apr 9, 2025
fbca090
Merge pull request #27893 from mattjj:shmap-vma-docs
Google-ML-Automation Apr 9, 2025
75e4279
Set `jax_varying_axes_in_types` to True by default.
yashk2810 Apr 9, 2025
178b2f1
Merge pull request #27888 from vfdev-5:add-314-cpython-update_one_slo…
Google-ML-Automation Apr 9, 2025
713ea3c
[JAX] Remove deprecated exports in jax.lib.xla_client.
hawkinsp Apr 9, 2025
c5bd13b
Refactor random_lax_test.py
jakevdp Apr 9, 2025
9320214
Make printing work with shard_map after vma has been switched on
yashk2810 Apr 10, 2025
a505520
Fix build breakage from missing _sdy_enums_gen.py.
hawkinsp Apr 10, 2025
2ceb97c
Restructure WeakrefLRUCache.
hawkinsp Apr 10, 2025
892cb65
[shard-map] good errors for pvary issues
mattjj Apr 10, 2025
cf268a7
Merge pull request #27895 from jakevdp:random-test-refactor
Google-ML-Automation Apr 10, 2025
382285d
Split JaxTestLoader and related classes into a separate file.
hawkinsp Apr 10, 2025
b4c3e38
When running test cases concurrently, log the start and end of each t…
hawkinsp Apr 10, 2025
f7a2760
Merge pull request #27831 from dfm:linear-call-recursion
Google-ML-Automation Apr 10, 2025
8f9f1aa
add sphinx extension and placeholder config docs rst
ZacCranko Apr 8, 2025
e1aa83a
Add JVP rule for linear_call.
dfm Apr 8, 2025
91d1434
Update XLA dependency to use revision
Google-ML-Automation Apr 10, 2025
ec59178
[Pallas:MGPU] Make sure to await all arrivals on consumed barriers
apaszke Apr 10, 2025
95f1207
Merge pull request #27843 from dfm:lin-call-jvp
Google-ML-Automation Apr 10, 2025
160bbe1
Fix shard_map docs build
yashk2810 Apr 10, 2025
7e2148b
[Pallas:MGPU] Don't assume we'll be running at least max_concurrent_s…
apaszke Apr 10, 2025
e287c7f
Minor adjustments in error messages in launch_context.py
Google-ML-Automation Apr 10, 2025
6dd576a
Add unit tests for the grouped query attention reference implementation
Google-ML-Automation Apr 10, 2025
c730bbd
fix bug in `export_module` when no mesh axes are empty for shardy.
liepieshov Apr 10, 2025
dd050f5
Unify markdown formatting (no visible change on GitHub).
arnoegw Apr 10, 2025
ed05bf8
Add a note about rotation direction for the tpu::RotateOp.
Google-ML-Automation Apr 10, 2025
9af0c05
[export] Add test that exporting works for experimental.compute_on.
gnecula Apr 9, 2025
6c0ac7a
Do a pvary in dynamic_slice_transpose_rule so that the `zeros` are va…
yashk2810 Apr 10, 2025
9011d66
Merge pull request #27903 from mattjj:pvary-errors
Google-ML-Automation Apr 10, 2025
a39a81a
Keep old scale_matmul arg names
kaixih Apr 10, 2025
0f29716
One alias one
kaixih Apr 10, 2025
2090dad
Deprecation warning
kaixih Apr 10, 2025
f3115d3
Fix dtype failures in JaxGroupedQueryAttentionReferenceTest.
dfm Apr 10, 2025
dc33db3
Skip Read the Docs builds unless the 'documentation' label is added.
dfm Apr 10, 2025
16ffbca
Merge pull request #27849 from ZacCranko:docfig
Google-ML-Automation Apr 10, 2025
9f7507f
Run notebooks as part of docs presubmit.
dfm Mar 24, 2025
349605c
Merge pull request #27917 from dfm:rtds-opt-in-label
Google-ML-Automation Apr 10, 2025
8482b7f
Merge pull request #27368 from dfm:docs-on-actions
Google-ML-Automation Apr 10, 2025
5f5e742
Mark as thread-unsafe tests that modify possibly-cached jaxprs in-place.
dougalm Apr 10, 2025
edc76c7
Add documentation for JAX's CI folder
nitins17 Apr 10, 2025
a940100
Enable execution of explicit-sharding notebook in docs.
dfm Apr 10, 2025
ae29f63
Don't use default quant config
kaixih Apr 10, 2025
64e10ad
Merge pull request #27924 from dfm:explicit-sharding-tutorial
Google-ML-Automation Apr 10, 2025
7117aa0
[Mosaic GPU] Skip WGMMA with cluster example on non H100 GPUs.
justinjfu Apr 10, 2025
2807ae4
[Pallas] Fix ()-shaped vectors being materialized in Pallas lowering.
justinjfu Apr 10, 2025
654b91b
Fix grep for label on Read the Docs.
dfm Apr 10, 2025
7e5966b
Make sure direct-linearize handles res_names correctly post vma in ty…
yashk2810 Apr 10, 2025
48e14dc
Implement mutation by replacing the contents of a jax.Array with a re…
pschuh Apr 10, 2025
92be510
[Mosaic GPU] Implement warp-level thread semantics.
justinjfu Apr 10, 2025
cf8a524
Update test shardings.
hawkinsp Apr 10, 2025
3864c4f
Allow ctrl-c to cancel block_until_ready().
pschuh Apr 10, 2025
59068ae
Remove unused jaxlib_mlir_capi targets.
hawkinsp Apr 10, 2025
41a8805
[pallas:mgpu] Return types allowed in mgpu.inline_mgpu.
cperivol Apr 10, 2025
b73bf1a
Update JAX continuous workflow to run once every 3 hours instead of 2.
nitins17 Apr 10, 2025
b352763
Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1.
hawkinsp Apr 10, 2025
6d57f00
[Mosaic:TPU][Relayout] Add implicit 2nd minor
tlongeri Apr 11, 2025
6e52b1e
optimize while_loop by moving readonly carry components to be consts
mattjj Apr 11, 2025
907725d
Merge pull request #27937 from mattjj:while-readonly-carry-optimization
Google-ML-Automation Apr 11, 2025
ffc33ab
Bump scipy build requirement on Python 3.13.
hawkinsp Apr 11, 2025
c5d6a19
Merge pull request #27938 from hawkinsp:scipy
Google-ML-Automation Apr 11, 2025
9f5f6ed
[Pallas] Fix integer array indexing
ayaka14732 Apr 11, 2025
7b7d36a
Add a 2D test in memories_test.
Google-ML-Automation Apr 11, 2025
d42d2e8
[Pallas] Interpret dimensions with parallel semantics by traversing t…
Google-ML-Automation Apr 11, 2025
96d38a6
[cache_misses] Skip tracing-cache-miss explanations for JAX internal …
gnecula Apr 10, 2025
8172220
Remove legacy CPU custom call kernels that have been unused since v0.…
dfm Apr 11, 2025
ac285a1
Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual
Google-ML-Automation Apr 11, 2025
1035c9a
Merge pull request #27916 from gnecula:tracing_cache_ignore_internals
Google-ML-Automation Apr 11, 2025
c9cbf82
Merge pull request #27876 from gnecula:aot_compute_on
Google-ML-Automation Apr 11, 2025
7eb397d
Make `trace` and `lower` class attributes for `jax.jit`.
gnecula Apr 9, 2025
a1c06fc
Merge pull request #27873 from gnecula:aot_wraps2
Google-ML-Automation Apr 11, 2025
896557f
Register NVPTX LLVM backend from Mosaic custom call
beckerhe Apr 11, 2025
b49972d
Move test skip for unary_ops_accuracy_test to a setUp method.
hawkinsp Apr 11, 2025
8082186
Fix api_test on persistent cache enabled platform
gnecula Apr 11, 2025
614ef37
Fix test flakiness in tpu_pallas_test when JAX_TEST_NUM_THREADS > 1.
hawkinsp Apr 11, 2025
d543df1
[pallas:mosaic_gpu] Added support for `unroll=True` to the `lax.fori_…
superbobry Apr 11, 2025
b3c0ec0
Update XLA dependency to use revision
Google-ML-Automation Apr 11, 2025
8b7319a
[JAX] Remove calls to jax.dlpack.to_dlpack(), and avoid passing DLPac…
hawkinsp Apr 11, 2025
3736e5b
Bump the JAX version to v0.6.0, which will be the next release version.
hawkinsp Apr 11, 2025
5adac1c
Fix the printing of the function name in tracing-cache-miss explanations
gnecula Apr 11, 2025
b1c96d4
Remove unused execute_sharded_* functions.
pschuh Apr 11, 2025
a39b623
Make sure the order passed to `make_jit` and `_parse_jit_arguments` i…
yashk2810 Apr 11, 2025
ab88273
Deprecate jax.dlpack.to_dlpack.
hawkinsp Apr 11, 2025
8e9fca1
document SPMD pipeline parallelism
Google-ML-Automation Apr 11, 2025
5cf74cc
Use dash instead of underscore for extras.
nitins17 Apr 11, 2025
27c07f7
[Pallas] Allow 1D iota
justinjfu Apr 11, 2025
904419c
Rename TPU bazel test tags.
hawkinsp Apr 11, 2025
e9364f4
Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
mattjj Apr 11, 2025
6efcf44
Deprecate `PositionalSharding` and `GSPMDSharding`
yashk2810 Apr 11, 2025
c0d97a6
Removed type annotations appear to be used and actually defined in py…
pschuh Apr 11, 2025
c90751b
Fix typo in jax.lax.linalg.symmetric_product description
ywrt Apr 11, 2025
6fc78a5
Deprecate jax.lax.infeed and jax.lax.outfeed.
hawkinsp Apr 11, 2025
b2a8df7
Add the `method` argument to `jax.numpy.isin` stub.
Google-ML-Automation Apr 11, 2025
b3f49e4
Re-landing #27937 with fewer bugs and more tests.
mattjj Apr 11, 2025
0fa732e
[ragged-paged-attn][NFC] Make validate_inputs functions take same inp…
bythew3i Apr 11, 2025
29f65f0
re-index jaxpr input effects in move_binders_to_front
mattjj Apr 11, 2025
1a4a86a
Merge pull request #27970 from mattjj:while-readonly-carry-optimization
Google-ML-Automation Apr 12, 2025
8afc833
Rename is_closed to is_open in the shardy shardings
yashk2810 Apr 12, 2025
e1cad34
Add `ChunkedCausalMask` for Splash Attention to support attention mas…
Google-ML-Automation Apr 12, 2025
dc10200
[explain-cache-miss] Improve the detection of user file names
gnecula Apr 12, 2025
19d3d95
unify `stages.Executable` and `stages.XlaExecutable`
froystig Apr 12, 2025
4ff78e6
Remove various methods from `MeshExecutable`
yashk2810 Apr 12, 2025
99ca146
revert making `Executable` an ABC
froystig Apr 12, 2025
566d077
unify `stages.Lowering` and `stages.XlaLowering`
froystig Apr 12, 2025
c69e61e
Remove jax.lib.xla_client.{XlaComputation,Shape}.
hawkinsp Apr 12, 2025
69173a2
Update XLA dependency to use revision
Google-ML-Automation Apr 12, 2025
ca50cae
Properly center and size the SM image in the GPU docs
apaszke Apr 9, 2025
a51307a
Merge pull request #27981 from apaszke:mgpu-sm-image
Google-ML-Automation Apr 13, 2025
7edd5d5
Add reference docs for Pallas:MGPU synchronization primitives
apaszke Apr 8, 2025
4fd610f
Update XLA dependency to use revision
Google-ML-Automation Apr 13, 2025
773b323
Merge pull request #27868 from apaszke:mgpu-synchronization-docs
Google-ML-Automation Apr 13, 2025
f070cde
[explain-cache-miss] Improve tracing-cache-miss explanations
gnecula Apr 10, 2025
2336cd1
Minor improvements to doc for jax.nn.logsumexp.
carlosgmartin Apr 13, 2025
2e4c0ec
[Mosaic:TPU] Add some invariant checking in VectorLayout ctor
tlongeri Apr 14, 2025
13c7183
add a brief description of the jax.Array-has-no-__iadd__ gotcha
mattjj Apr 11, 2025
1af747b
Merge pull request #27973 from mattjj:iadd-gotcha
Google-ML-Automation Apr 14, 2025
6ca623f
Merge pull request #27980 from gnecula:tracing_cache
Google-ML-Automation Apr 14, 2025
3a7cec8
Add Pallas:MGPU documentation for WGMMA
apaszke Apr 13, 2025
ec8c065
Merge pull request #27982 from apaszke:mgpu-wgmma-docs
Google-ML-Automation Apr 14, 2025
b8df474
[explain_cache_miss] Add to explanations the duration of the missed f…
gnecula Apr 14, 2025
95e0c2a
Merge pull request #27925 from dfm:debug-rtds
Google-ML-Automation Apr 14, 2025
11a6abc
Remove accidental tab characters from Pallas:MGPU docs
apaszke Apr 14, 2025
9abc74d
Merge pull request #27997 from apaszke:fix-tabs
Google-ML-Automation Apr 14, 2025
b6c6c1c
Merge pull request #27971 from ywrt:patch-1
Google-ML-Automation Apr 14, 2025
077d134
Adjust test expectations for the tracing-cache-miss-explanations
gnecula Apr 14, 2025
1b1bd07
Finalize deprecation of vectorized argument in callbacks.
dfm Apr 14, 2025
5af5925
Update XLA dependency to use revision
Google-ML-Automation Apr 14, 2025
ceca6ec
jax.jit: deprecate non-standard call signature.
jakevdp Apr 11, 2025
30669dc
Merge pull request #27993 from gnecula:explain_timing
Google-ML-Automation Apr 14, 2025
42542fe
jnp.power: better docs for invalid input
jakevdp Apr 14, 2025
785d077
Disable unknown warning option error on Mac.
nitins17 Apr 14, 2025
6fcb036
Merge pull request #27966 from jakevdp:jit-signature
Google-ML-Automation Apr 14, 2025
d014912
Merge pull request #28007 from jakevdp:int-power
Google-ML-Automation Apr 14, 2025
1fcb2b4
Reinstate lifegiving chaos line to docs.
emilyfertig Apr 14, 2025
8930a67
Fix stablehlo version comparison in test utilities.
hawkinsp Apr 14, 2025
73305e0
Update issue template with correct URL for untemplated issue
jakevdp Apr 14, 2025
8bae29d
Merge pull request #28020 from jakevdp:issue-template
Google-ML-Automation Apr 14, 2025
19be20f
Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
Google-ML-Automation Apr 14, 2025
a64e7dc
Merge pull request #28012 from emilyfertig:emilyaf-random-docs-line
Google-ML-Automation Apr 14, 2025
a88486c
Fix warnings in array_interoperability_test.
hawkinsp Apr 14, 2025
ab600c3
Remove obsolete python key path registry.
IvyZX Apr 14, 2025
7b4b2f4
Fixed the way to skip tests using optional python dependencies
vfdev-5 Apr 14, 2025
57e33bc
Deprecate the contents of jax.util.
hawkinsp Apr 15, 2025
69d21c6
Merge pull request #27999 from vfdev-5:fix-skip-test-pattern-with-opt…
Google-ML-Automation Apr 15, 2025
4fa3cd9
[Pallas/Fuser] Add basic closed over consts support to pull_block_spec
sharadmv Apr 15, 2025
0ed0fb7
Adds a debugging message to assert, otherwise the error is pretty cry…
marksandler2 Apr 15, 2025
1926b99
[pallas] Fix spelling of 'fusible'.
chr1sj0nes Apr 15, 2025
09edc49
Add explicit_axes section to the doc
yashk2810 Apr 15, 2025
aed3297
[CI] Update GPU optional presubmit naming
MichaelHudgins Apr 15, 2025
06a77b7
[CI] Propagate halt connection to tpu tests
MichaelHudgins Apr 15, 2025
0b04739
Deprecate the remaining exports from jax.lib.xla_client.
hawkinsp Apr 15, 2025
4d692d1
Update XLA dependency to use revision
Google-ML-Automation Apr 15, 2025
c56cf4f
jax.random.bernoulli: use higher-resolution sampler
jakevdp Apr 15, 2025
3b359ba
Merge pull request #28027 from jax-ml:explicit_axes
Google-ML-Automation Apr 15, 2025
34c2dbf
Fix lowering code for ROCm RNN
Ruturaj4 Mar 20, 2025
87e4b5f
Merge pull request #28036 from ROCm:ci_fix_rnn_lowering-upstream
Google-ML-Automation Apr 15, 2025
b336daf
Merge pull request #28022 from jakevdp:bernoulli-fix
Google-ML-Automation Apr 15, 2025
c527ddb
[Mosaic:TPU] Fix bug in `rotateVregRows`
tlongeri Apr 15, 2025
6e00b5e
[NFC] Rename `standard_insert_pbroadcast` to `standard_insert_pvary`
yashk2810 Apr 15, 2025
ba88777
Roll back https://github.com/jax-ml/jax/pull/28022 due to test breaka…
Apr 15, 2025
393c555
Fix bugs in tp_traverse handlers.
hawkinsp Apr 15, 2025
b271a67
Clean up softmax initial deprecation
jakevdp Apr 15, 2025
7388913
[ragged-paged-attn][NFC] Set kv_pages_per_blk uplimit.
bythew3i Apr 15, 2025
655bfca
Enable standard_insert_pvary for optimization_barrier which was disab…
yashk2810 Apr 15, 2025
002be7a
Merge pull request #28047 from jakevdp:logsoftmax-dep
Google-ML-Automation Apr 15, 2025
25e0fe5
Merge pull request #27984 from carlosgmartin:logsumexp_doc
Google-ML-Automation Apr 15, 2025
47bc2f5
convert NumPy RNG key data to uncommitted default-device-backed `jax.…
froystig Apr 16, 2025
90af597
remove inaccurate inline comment in `PRNGKeyArray` constructor
froystig Apr 16, 2025
2beff6a
[pallas] Fix case of `Fusible{ElementDtype,TyRules}`.
chr1sj0nes Apr 16, 2025
5520055
Fix test flakyness by blocking until the data is ready.
pschuh Apr 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
24 changes: 16 additions & 8 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ build -c opt
build --output_filter=DONT_MATCH_ANYTHING

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
build --copt=-DNB_DOMAIN=jax

# #############################################################################
# Platform Specific configs below. These are automatically picked up by Bazel
Expand Down Expand Up @@ -97,6 +98,7 @@ build:windows --incompatible_strict_action_env=true
# #############################################################################
build:nonccl --define=no_nccl_support=true

build --repo_env USE_PYWRAP_RULES=1
build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
Expand Down Expand Up @@ -130,19 +132,21 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags
# Required when building with clang>=19, see jax-ml/jax#27091
build:clang --copt=-Wno-error=c23-extensions

# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This config is used for building targets with CUDA libraries from stubs.
Expand Down Expand Up @@ -253,15 +257,19 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm

# Mac Arm64 CI configs
build:ci_darwin_arm64 --macos_minimum_os=11.0
# Clang 19 requires `-Wno-error=c23-extensions` but this flag is not supported
# on Apple Clang in XCode 16.0 so we suppress unknown warning option errors
# on Mac CI builds.
build:ci_darwin_arm64 --copt=-Wno-unknown-warning-option
build:ci_darwin_arm64 --config=macos_cache_push
build:ci_darwin_arm64 --verbose_failures=true
build:ci_darwin_arm64 --color=yes

# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

Expand Down Expand Up @@ -329,9 +337,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst
build:rbe_windows_amd64 --config=rbe

# Set the host, execution, and target platform
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"

build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
build:rbe_windows_amd64 --enable_runfiles
Expand Down
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ body:

[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues

[Raw report]: http://github.com/jax-ml/jax/issues/new
[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
- type: textarea
attributes:
label: Description
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/bazel_cuda_non_rbe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest"

env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
Expand Down Expand Up @@ -79,6 +79,7 @@ jobs:
continue-on-error: true
run: >-
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
Expand Down
62 changes: 62 additions & 0 deletions .github/workflows/bazel_optional_b200.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
name: CI - Bazel Optional B200 CUDA tests
on:
# Runs on PR if label "CI Optional GPU Presubmit" is present.
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'
pull_request:
branches:
- main
types: [ labeled, synchronize ]
schedule:
- cron: "0 */2 * * *" # Run once every 2 hours
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
run_tests:
if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
runs-on: linux-x86-a4-224-b200-1gpu
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest'
name: "Bazel single B200 CUDA tests"
# End Presubmit Naming Check github-cuda-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CUDA Tests
run: |
nvidia-smi
bazel test --config=ci_linux_x86_64_cuda \
--config=resultstore \
--config=rbe_cache \
--repo_env=HERMETIC_CUDA_VERSION="12.8.0" \
--repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \
--repo_env=HERMETIC_PYTHON_VERSION="3.13" \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
--run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
--test_output=errors \
--test_env=JAX_ACCELERATOR_COUNT=1 \
--test_env=JAX_TESTS_PER_ACCELERATOR=32 \
--local_test_jobs=32 \
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
--test_tag_filters=-multiaccelerator \
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
--test_env=JAX_SKIP_SLOW_TESTS=true \
--action_env=JAX_ENABLE_X64="1" \
--action_env=NCCL_DEBUG=WARN \
--color=yes \
//tests:gpu_tests //tests:backend_independent_tests \
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
12 changes: 6 additions & 6 deletions .github/workflows/build_artifacts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ on:
default: "linux-x86-n2-16"
options:
- "linux-x86-n2-16"
- "linux-arm64-c4a-64"
- "windows-x86-n2-64"
- "linux-arm64-t2a-48"
- "windows-x86-n2-16"
artifact:
description: "Which JAX artifact to build?"
type: choice
Expand Down Expand Up @@ -119,11 +119,11 @@ jobs:

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Enable RBE if building on Linux x86
if: contains(inputs.runner, 'linux-x86')
- name: Enable RBE if building on Linux x86 or Windows x86
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86
if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86')
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
if: contains(inputs.runner, 'linux-arm64')
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
# Halt for testing
- name: Wait For Connection
Expand Down
18 changes: 9 additions & 9 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python 3.11
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.11
- run: python -m pip install pre-commit
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
- uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
Expand Down Expand Up @@ -70,7 +70,7 @@ jobs:
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -140,9 +140,9 @@ jobs:
- name: Image Setup
run: |
apt update
apt install -y libssl-dev libsqlite3-dev
apt install -y libssl-dev libsqlite3-dev build-essential
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -151,7 +151,7 @@ jobs:
uv pip install --system -r docs/requirements.txt
- name: Render documentation
run: |
sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
sphinx-build -j auto --color -W --keep-going -b html docs docs/build/html

jax2tf_test:
name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
Expand All @@ -168,7 +168,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down Expand Up @@ -201,7 +201,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.12
- name: Install JAX
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ jobs:
matrix:
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/cloud-tpu-ci-presubmit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ jobs:
python: "3.10"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
halt-for-connection: ${{ inputs.halt-for-connection || false }}
# End Presubmit Naming Check github-tpu-presubmits
34 changes: 34 additions & 0 deletions .github/workflows/community_release_actions.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Release Actions

on:
release:
types: [published]

permissions:
contents: read

jobs:
discord_release:
if: github.repository_owner == 'jax-ml'
runs-on: ubuntu-latest
steps:
- name: Get release URL
id: get-release-url
run: |
URL="https://docs.jax.dev/en/latest/changelog.html"
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
with:
stringToTruncate: |
JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!

${{ github.event.release.body }}
maxLength: 2000
truncationSymbol: "..."
- name: Discord Webhook Action
uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
4 changes: 2 additions & 2 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
Loading
Loading