@@ -968,15 +968,15 @@ def shard_map_error_check(
968
968
new_in_names = (* ([{}] * num_error_vals ), * in_names )
969
969
new_vals_in = [* err_vals , * vals_in ]
970
970
in_avals = list (map (core .get_aval , new_vals_in ))
971
- auto = kwargs .get ('auto ' )
971
+ manual_axes = kwargs .get ('manual_axes ' )
972
972
check_vma = kwargs .get ('check_vma' )
973
973
for i , v in enumerate (in_avals ):
974
974
if not (sharder := core .shard_aval_handlers .get (type (v ))):
975
975
raise ValueError (f'Unsupported aval type: { type (v )} ' )
976
- in_avals [i ] = sharder (mesh , auto , check_vma , new_in_names [i ], v )
976
+ in_avals [i ] = sharder (mesh , manual_axes , check_vma , new_in_names [i ], v )
977
977
978
- with (jshmap ._extend_axis_env (mesh , auto ),
979
- mesh_lib .use_abstract_mesh (jshmap ._as_manual_mesh (mesh , auto )),
978
+ with (jshmap ._extend_axis_env (mesh , manual_axes ),
979
+ mesh_lib .use_abstract_mesh (jshmap ._as_manual_mesh (mesh , manual_axes )),
980
980
config ._check_vma (check_vma )):
981
981
# jaxpr to checked_jaxpr
982
982
checked_jaxpr , out_tree , _ = jaxpr_to_checkify_jaxpr (
0 commit comments