Skip to content

Commit 66b6eaa

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Replace auto tracking with manual axis names tracking internally in shard_map.
This will make it easier to implement nested shard_map where you enter into manual mode one axis at a time per shard_map. PiperOrigin-RevId: 751509665
1 parent ae2d35a commit 66b6eaa

File tree

3 files changed

+146
-139
lines changed

3 files changed

+146
-139
lines changed

jax/_src/checkify.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -968,15 +968,15 @@ def shard_map_error_check(
968968
new_in_names = (*([{}] * num_error_vals), *in_names)
969969
new_vals_in = [*err_vals, *vals_in]
970970
in_avals = list(map(core.get_aval, new_vals_in))
971-
auto = kwargs.get('auto')
971+
manual_axes = kwargs.get('manual_axes')
972972
check_vma = kwargs.get('check_vma')
973973
for i, v in enumerate(in_avals):
974974
if not (sharder := core.shard_aval_handlers.get(type(v))):
975975
raise ValueError(f'Unsupported aval type: {type(v)}')
976-
in_avals[i] = sharder(mesh, auto, check_vma, new_in_names[i], v)
976+
in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_names[i], v)
977977

978-
with (jshmap._extend_axis_env(mesh, auto),
979-
mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, auto)),
978+
with (jshmap._extend_axis_env(mesh, manual_axes),
979+
mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)),
980980
config._check_vma(check_vma)):
981981
# jaxpr to checked_jaxpr
982982
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(

jax/_src/interpreters/pxla.py

-2
Original file line numberDiff line numberDiff line change
@@ -3005,8 +3005,6 @@ def from_hlo(name: str,
30053005
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
30063006
compiler_options_kvs, pgle_profiler)
30073007

3008-
orig_out_shardings = out_shardings
3009-
30103008
if auto_spmd_lowering:
30113009
assert mesh is not None
30123010
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(

0 commit comments

Comments
 (0)