forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
Ci upstream sync 168 1 reb #357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- `dimensions().size()` if it's OK for the result to be changed to an unsigned number, - `dimensions_size()` if it's important that the result is a signed number. This should be a pure refactoring that doesn't affect the code's behavior. Note that `rank()` returns `int64_t` and `dimensions().size()` returns `size_t`. Sometimes the change of the signedness is not desirable, and we use `dimensions_size()`, which returns `int`, in such cases. PiperOrigin-RevId: 741524661
…lt vector size in layout inference. PiperOrigin-RevId: 741528085
- ShardingInTypesTest.test_set_mesh - APITest.test_cache_clear_pmap This helps to prevent errors like: 1) in pjit_test.py: ``` ValueError: For primitive mul, context mesh AbstractMesh('x': 2, axis_types=(Explicit,)) should match the aval mesh AbstractMesh('x': 2, 'y': 1, axis_types=(Auto, Auto)) for shape float32[8,2] ``` raised for example by ArrayPjitTest.test_pjit_array_multi_input_multi_output_mesh3 and also by ArrayPjitTest.test_convert_element_type_sharding, when pjit tests are run concurrently with `--local_test_jobs=32` and `--test_env=JAX_TEST_NUM_THREADS=8` 2) in api_test.py ``` AssertionError: Expected exactly 1 XLA compilations, but executed 2 ``` raised by APITest.test_pmap_global_cache.
PiperOrigin-RevId: 741529177
PiperOrigin-RevId: 741534342
Also, switch the Linux aarch64 runner type to t2a as we run the tests on t2a. PiperOrigin-RevId: 741538543
The function itself was already deleted. PiperOrigin-RevId: 741546212
call_tf is the only remaining user of the XlaComputation type in JAX. Change it to use a new helper function that converts an HLO proto to stablehlo bytecode without using the XlaComputation Python bindings. Also port the code to parse types from the stablehlo rather than the HLO. Remove jax.interpreters.mlir.xla_computation_to_mlir_module. PiperOrigin-RevId: 741548298
It is clearer to use a flag to indicate the first step than to use a step counter == 0, since in theory the step counter (a 32 bit integer in the code) can wrap around back to zero, even though this will unlikely happen since there are way less than 2**32 blocks. PiperOrigin-RevId: 741551623
…easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though. PiperOrigin-RevId: 741554623
PiperOrigin-RevId: 741558343
No code functionality change in this commit. PiperOrigin-RevId: 741566312
Specialize it to one shape per aval, since that's the only case that exists. Remove some pointless assertions using this code. PiperOrigin-RevId: 741569024
…ured way This is hopefully less confusing then bunching them together in a single argument. PiperOrigin-RevId: 741580827
… lane-level lowering More work is needed to support these in the WG lowering. PiperOrigin-RevId: 741622096
This is a precautionary measure to prevent conflicts with other packages using nanobind and registering the same types. We don't want JAX's nanobind registrations to conflict on, say, XLA types with other projects.
PiperOrigin-RevId: 741644574
…sired dtype back to HBM. We use f32 as the dtype inside the kernel. Before we write the result from vmem to hbm, we convert to the desired dtype (eg bf16). So we can save memory bandwidth. Also, made minor change by checking sliding window and logit soft capping in the function that checks the static value. PiperOrigin-RevId: 741660728
… primitives PiperOrigin-RevId: 741661360
Imported from GitHub PR jax-ml#27576 This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated. This PR also includes some fixes for getattr/setattr. Copybara import of the project: -- 3b1ea1a by Matthew Johnson <[email protected]>: [attrs] experimental appendattr Merging this change closes jax-ml#27576 COPYBARA_INTEGRATE_REVIEW=jax-ml#27576 from mattjj:appendattr b937952 PiperOrigin-RevId: 741662724
The GPU-specific deps were added to the backend-independent tests by mistake [here](jax-ml#27113). These tests should pass using `jax` and `jaxlib` wheels only. PiperOrigin-RevId: 741663266
Update community_release_actions.yml
PiperOrigin-RevId: 741685454
…rging dim was `1`. PiperOrigin-RevId: 741740811
http://github.com/openxla/xla/commit/f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0. PiperOrigin-RevId: 741809075
So duplicated load/store ops can be removed. PiperOrigin-RevId: 745209849
None is meant to represent the same thing as {replicated over all possible axes}. But without this canonicalization, we could compare None as not equal to {all possible axes}. fixes jax-ml#26621 Unrelated: in several places, including the _check_rep path, we don't handle partial auto correctly, since we treat {all possible axes} as {all mesh axes}, but actually it should be more like {all mesh axes} - auto. We'll leave that fix for a follow-up...
See jax-ml#18711 check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants. The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
Pass pytype_srcs as data to the pybind_extension rule. PiperOrigin-RevId: 745238783
PiperOrigin-RevId: 745247778
These should be used directly from ml_dtypes. PiperOrigin-RevId: 745256523
PiperOrigin-RevId: 745261995
Now that jax-ml@db11efa has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA. There's no reason weakref_lru_cache is in the same Python extension as everything else. PiperOrigin-RevId: 745271825
…es` to True. The main changes here are: * Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead. * Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`. * Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`. * Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on. Co-authored-by: Matthew Johnson <[email protected]> PiperOrigin-RevId: 745276474
I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B. PiperOrigin-RevId: 745294058
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement. PiperOrigin-RevId: 745333506
…JAX_SKIP_SLOW_TESTS=true Description: - Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython - Removed optional deps for 3.14
PiperOrigin-RevId: 745342103
…houldn't expose this to public API and have users use `psum` instead which will dispatch to `psum_invariant` when `check_rep=True`. PiperOrigin-RevId: 745352875
PiperOrigin-RevId: 745375892
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.