diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index a645f6c71249..c8caffeb7877 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -968,15 +968,15 @@ def shard_map_error_check( new_in_names = (*([{}] * num_error_vals), *in_names) new_vals_in = [*err_vals, *vals_in] in_avals = list(map(core.get_aval, new_vals_in)) - auto = kwargs.get('auto') + manual_axes = kwargs.get('manual_axes') check_vma = kwargs.get('check_vma') for i, v in enumerate(in_avals): if not (sharder := core.shard_aval_handlers.get(type(v))): raise ValueError(f'Unsupported aval type: {type(v)}') - in_avals[i] = sharder(mesh, auto, check_vma, new_in_names[i], v) + in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_names[i], v) - with (jshmap._extend_axis_env(mesh, auto), - mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, auto)), + with (jshmap._extend_axis_env(mesh, manual_axes), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f337c996d3e3..6c9c6d2aad68 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3005,8 +3005,6 @@ def from_hlo(name: str, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler) - orig_out_shardings = out_shardings - if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index fce7ed824cb0..60e1991e2d1b 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -153,10 +153,10 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, axis_names = frozenset(axis_names) if not axis_names: axis_names = frozenset(mesh.axis_names) - auto = frozenset(mesh.axis_names) - frozenset(axis_names) - if not auto.issubset(mesh.axis_names): - raise ValueError(f"shard_map requires auto={auto} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") + if not axis_names.issubset(mesh.axis_names): + raise ValueError( + f"jax.shard_map requires axis_names={axis_names} to be a subset of " + f"mesh.axis_names={mesh.axis_names}") # TODO(yashkatariya): Maybe we don't have to be this strict? if mesh._any_axis_auto_or_manual and in_specs is None: @@ -166,9 +166,9 @@ def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, f" {mesh} has `Auto` axes.\n") if in_specs is not None: - _check_specs(SpecErrorType.input, in_specs, auto) + _check_specs(SpecErrorType.input, in_specs, axis_names) if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, auto) + _check_specs(SpecErrorType.out, out_specs, axis_names) @util.wraps(f) @traceback_util.api_boundary @@ -202,7 +202,7 @@ def wrapped(*args): def out_names_thunk(): if callable(out_specs): out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_, auto) + _check_specs(SpecErrorType.out, out_specs_, axis_names) else: out_specs_ = out_specs dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) @@ -219,7 +219,8 @@ def out_names_thunk(): try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_vma=check_vma, auto=auto) + out_names_thunk=out_names_thunk, check_vma=check_vma, + manual_axes=axis_names) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -268,38 +269,39 @@ def _manual_spec(manual_axes, spec: P) -> P: SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) -def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: +def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: if error_type == SpecErrorType.input and specs is None: raise TypeError( "shard_map in_specs argument must be a pytree of " "`jax.sharding.PartitionSpec` instances, but it was None.\n" "Instead of `in_specs=None`, did you mean `in_specs=P()`, " "where `P = jax.sharding.PartitionSpec`?") + def check_spec(p): if not isinstance(p, PartitionSpec): return False for names in p: - if not isinstance(names, tuple): - names = (names,) + names = (names,) if not isinstance(names, tuple) else names for name in names: - if name in auto: + if name is not None and name not in manual_axes: return False return True - if all(check_spec(p) for p in tree_leaves(specs)): return + + if all(check_spec(p) for p in tree_leaves(specs)): + return prefix = 'in' if error_type == SpecErrorType.input else 'out' msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " for key, x in generate_key_paths(specs) if not isinstance(x, P)] if not msgs: for key, p in generate_key_paths(specs): for names in p: - if not isinstance(names, tuple): - names = (names,) + names = (names,) if not isinstance(names, tuple) else names for name in names: - if name in auto: + if name is not None and name not in manual_axes: msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") raise ValueError( - f"shard_map {prefix}_specs argument cannot refer to an axis " - f"marked auto ({auto}), but:\n\n" + f"shard_map {prefix}_specs argument must refer to an axis " + f"marked as manual ({manual_axes}), but:\n\n" + '\n\n'.join(msgs) + '\n\n' f"Check the {prefix}_specs values passed to shard_map.") raise TypeError( @@ -527,13 +529,13 @@ def get_bind_params(self, params): # Staging @util.cache(max_size=256, trace_context_in_key=True) -def _as_manual_mesh(mesh, auto: frozenset): - manual_axes = tuple(set(mesh.axis_names) - auto) +def _as_manual_mesh(mesh, manual_axes: frozenset): + not_manual = set(mesh.axis_names) - manual_axes cur_mesh = get_abstract_mesh() if cur_mesh.empty: cur_mesh = mesh explicit_axes, auto_axes = set(), set() # type: ignore - for a in auto: + for a in not_manual: if cur_mesh._name_to_type[a] == AxisType.Auto: auto_axes.add(a) else: @@ -553,9 +555,9 @@ def _as_manual_mesh(mesh, auto: frozenset): axis_types=tuple(new_axis_types)) -def _extend_axis_env(mesh, auto): +def _extend_axis_env(mesh, manual_axes): return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k not in auto]) + if k in manual_axes]) def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, @@ -563,22 +565,22 @@ def _shard_map_staging( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_vma: bool, - auto: frozenset, + manual_axes: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: source_info = source_info_util.current() to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) in_tracers = map(to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto, check_vma), in_names, + in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_names, in_avals) - manual_mesh = _as_manual_mesh(mesh, auto) - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + manual_mesh = _as_manual_mesh(mesh, manual_axes) + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_vma: out_vma = [v.aval.vma for v in jaxpr.outvars] - _check_reps(mesh, auto, out_names_thunk(), out_vma) + _check_reps(mesh, out_names_thunk(), out_vma) out_avals = map(_check_shapedarray, out_avals_) out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] @@ -587,12 +589,12 @@ def _shard_map_staging( constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, effs, source_info) @@ -606,11 +608,11 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval -def _shard_aval(mesh: Mesh, auto, check_vma, names: AxisNames, +def _shard_aval(mesh: Mesh, manual_axes, check_vma, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, check_vma, names, - aval) + return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma, + names, aval) raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, @@ -620,12 +622,13 @@ def _unshard_aval(mesh: Mesh, check_vma, names: AxisNames, else: raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_vma, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: +def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, + names: AxisNames, aval: core.AbstractValue + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, auto) + manual_mesh = _as_manual_mesh(mesh, manual_axes) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) vma = (frozenset({n for ns in names.values() for n in ns}) if check_vma else frozenset()) @@ -669,19 +672,19 @@ def _unshard_shaped_array(mesh: Mesh, check_vma, names: AxisNames, RepType = Any def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): # TODO(mattjj,parkers): check auto for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval( - mesh, auto, check_vma, in_name, x.aval)): + mesh, manual_axes, check_vma, in_name, x.aval)): raise core.JaxprTypeError("shard_map argument avals not compatible with " "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): core.check_jaxpr(jaxpr) if check_vma: - out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] + out_rep = [_vma_to_rep(mesh, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, auto, rep, dst): + if not _valid_repeats(mesh, rep, dst): raise core.JaxprTypeError( "shard_map can't prove output is sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] @@ -692,14 +695,14 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, core.custom_typechecks[shard_map_p] = _shard_map_typecheck -def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: - return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) +def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: + return rep is None or (set(_unmentioned(mesh, dst))).issubset(rep) # Lowering def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in + ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in ) -> sharding_impls.SdyArraySharding: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) @@ -707,24 +710,23 @@ def _shardy_shard_map_sharding( ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) - if auto: + if len(manual_axes) < len(mesh.axis_names): for dim_sharding in sdy_sharding.dimension_shardings: dim_sharding.is_open = True return sdy_sharding def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma): + ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma): + axis_ctx = ctx.module_context.axis_context in_avals_ = [v.aval for v in jaxpr.invars] - if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): # Nested `ManualComputationOp`s cannot refer to axes that are already # manual. So figure out what axes are free thus far. - free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - shardy_manual_axes = free_axes - auto + shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes else: - shardy_manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto) + shardy_manual_axes = manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) # The order of manual axes should match the order of mesh.axis_names to avoid @@ -733,17 +735,17 @@ def _shard_map_lowering_shardy( if a in shardy_manual_axes] if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): out_nodes, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, dim_var_values=ctx.dim_var_values) return out_nodes in_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), + partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), in_names, ctx.avals_in)).build() out_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), + partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), out_names, ctx.avals_out)).build() output_types = map(mlir.aval_to_ir_type, ctx.avals_out) manual_computation_op = sdy.ManualComputationOp( @@ -752,7 +754,7 @@ def _shard_map_lowering_shardy( ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) block = ir.Block.create_at_start( manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), + with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma)): out_nodes_, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, @@ -763,27 +765,26 @@ def _shard_map_lowering_shardy( def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_vma) + ctx, in_nodes, jaxpr, mesh, in_names, out_names, manual_axes, check_vma) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, - in_avals_, in_nodes) - manual_axes = frozenset(mesh.axis_names) - auto + in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_names, + ctx.avals_in, in_avals_, in_nodes) new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): out_nodes_, tokens_out = mlir.call_lowering( "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, - ctx.avals_out, out_nodes_) + return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_names, + out_avals_, ctx.avals_out, out_nodes_) mlir.register_lowering(shard_map_p, _shard_map_lowering) def _make_scoped_manual_sharding(ctx, mesh, axes): @@ -795,9 +796,9 @@ def _make_scoped_manual_sharding(ctx, mesh, axes): return NamedSharding( mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) @@ -805,25 +806,27 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() - unspecified = set(range(aval_in.ndim)) if auto else set() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, names, aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: return x axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_out, ns) aval_out = core.physical_aval(aval_out) - unspecified = set(range(aval_out.ndim)) if auto else set() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) + manual_proto = pxla.manual_proto(aval_in, manual_axes, mesh) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, @@ -859,12 +862,13 @@ def _vma_to_spec(mesh, vma): def _names_to_vma(names): return {n for ns in names.values() for n in ns} -def _vma_to_rep(mesh, auto, vma): - return frozenset((set(mesh.axis_names) - auto) - vma) +def _vma_to_rep(mesh, vma): + return set(mesh.axis_names) - vma def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_vma, auto): - if auto: raise NotImplementedError + check_vma, manual_axes): + if len(manual_axes) < len(mesh.axis_names): + raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): mesh = get_mesh_from_args(args, mesh) @@ -872,11 +876,12 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh), in_names, args) in_vma = map(_names_to_vma, in_names) - outs, out_vma = _run_shmap(fun, mesh, auto, args, in_vma, check_vma, cur_mesh) + outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma, + cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_vma: - _check_reps(mesh, auto, out_names_thunk(), out_vma) + _check_reps(mesh, out_names_thunk(), out_vma) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) else: src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) @@ -885,11 +890,11 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, outs) core.EvalTrace.process_shard_map = _shard_map_impl -def _run_shmap(f, mesh, auto, args, vmas, check_vma, context_mesh): - trace = ShardMapTrace(mesh, auto, check_vma, context_mesh) +def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh): + trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh) in_tracers = map(partial(ShardMapTracer, trace), vmas, args) - manual_mesh = _as_manual_mesh(mesh, auto) - with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), + manual_mesh = _as_manual_mesh(mesh, manual_axes) + with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh), config._check_vma(check_vma)): ans = f.call_wrapped(*in_tracers) outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) @@ -928,9 +933,9 @@ def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] class _SpecError(Exception): pass -def _check_reps(mesh, auto, names, vmas): - reps = [_vma_to_rep(mesh, auto, v) for v in vmas] - fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail +def _check_reps(mesh, names, vmas): + reps = [_vma_to_rep(mesh, v) for v in vmas] + fail = [r if not _valid_repeats(mesh, r, n) else no_fail for n, r in zip(names, reps)] if any(f is not no_fail for f in fail): raise _RepError(fail) @@ -961,17 +966,17 @@ def _maybe_check_special(outs): raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "auto", "check", "context_mesh") + __slots__ = ("mesh", "manual_axes", "check", "context_mesh") mesh: Mesh - auto: frozenset[AxisName] + manual_axes: frozenset[AxisName] check: bool context_mesh: AbstractMesh - def __init__(self, mesh, auto, check, context_mesh): + def __init__(self, mesh, manual_axes, check, context_mesh): super().__init__() self.mesh = mesh - self.auto = auto + self.manual_axes = manual_axes self.check = check self.context_mesh = context_mesh @@ -1030,8 +1035,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. del prim, jvp, symbolic_zeros in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) return map(partial(ShardMapTracer, self), out_vma, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, @@ -1043,8 +1048,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) - out_vals, out_vma = _run_shmap(fun, self.mesh, self.auto, in_vals, in_vma, - self.check, self.context_mesh) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) return map(partial(ShardMapTracer, self), out_vma, out_vals) @@ -1065,7 +1070,7 @@ def aval(self): aval = core.get_aval(self.val) out = core.mapped_aval(self._trace.mesh.size, 0, aval) new_sharding = NamedSharding( - _as_manual_mesh(self._trace.mesh, self._trace.auto), + _as_manual_mesh(self._trace.mesh, self._trace.manual_axes), out.sharding.spec) # pytype: disable=attribute-error vma = self.vma if config._check_vma.value else frozenset() return out.update(sharding=new_sharding, vma=vma) @@ -1136,7 +1141,7 @@ def _shard_map_batch( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_vma: bool, - auto: frozenset) -> Sequence[batching.BatchTracer]: + manual_axes: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError @@ -1161,7 +1166,7 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, @@ -1184,7 +1189,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -1199,7 +1204,7 @@ def new_out_names_thunk(): return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) @@ -1210,18 +1215,18 @@ def new_out_names_thunk(): def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_vma), + in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma), unk_in_names, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits2( f, (*in_knowns,), (*in_avals_sharded,)) - all_names = _all_newly_manual_mesh_names(mesh, auto) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) @as_hashable_function(closure=out_names_thunk) def known_out_names(): @@ -1236,7 +1241,7 @@ def known_out_names(): known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() @@ -1267,7 +1272,7 @@ def known_out_names(): out_avals_sharded = [v.aval for v in jaxpr.outvars] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_names, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) @@ -1282,12 +1287,12 @@ def known_out_names(): def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_vma, auto): + out_names_thunk, check_vma, manual_axes): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - all_names = _all_newly_manual_mesh_names(mesh, auto) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): @@ -1303,7 +1308,8 @@ def fwd_out_names_thunk(): return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, auto=auto) + out_names_thunk=fwd_out_names_thunk, check_vma=check_vma, + manual_axes=manual_axes) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() @@ -1313,8 +1319,8 @@ def fwd_out_names_thunk(): residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() @@ -1341,7 +1347,7 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_vma=check_vma, auto=auto) + check_vma=check_vma, manual_axes=manual_axes) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): @@ -1393,26 +1399,27 @@ def fun(*res_and_args): def _unmentioned2(mesh: Mesh, names: AxisNames, - auto: frozenset[AxisName]) -> list[AxisName]: + manual_axes: frozenset[AxisName]) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over - # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} | auto - return [n for n in _all_mesh_names_except_spmd(mesh, auto) + # more chips than required in the transpose-no-check-vma case. + name_set = {n for ns in names.values() for n in ns} + return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_vma, auto): + check_vma, manual_axes): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ - ad.Zero(_shard_aval(mesh, auto, check_vma, ns, x.aval)) + ad.Zero(_shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, manual_axes)))) for ns, x in zip(out_names, out_cts) ] args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, check_vma, ns, x.aval)) + ad.UndefinedPrimal( + _shard_aval(mesh, manual_axes, check_vma, ns, x.aval)) for ns, x in zip(in_names, args)) all_args, in_tree = tree_flatten((out_cts, args)) @@ -1429,7 +1436,7 @@ def fun_trans_callable(out_cts, args): _, in_ct_names = partition_list(in_undef, in_names) in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, ns, x.aval)) if type(x) is ad.Zero else x if check_vma - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, manual_axes))) for ns, x in zip(in_ct_names, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] return merge_lists(in_undef, res_zeros, in_cts) @@ -1449,7 +1456,7 @@ def new_out_names_thunk(): out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " "function. Calling the de-optimized backward pass.") @@ -1460,7 +1467,7 @@ def new_out_names_thunk(): _ = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_vma=check_vma, - auto=auto) + manual_axes=manual_axes) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: @@ -1476,8 +1483,8 @@ def _partial_eval_jaxpr_custom_rule( ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - check_vma, auto = eqn.params['check_vma'], eqn.params['auto'] - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res @@ -1487,8 +1494,8 @@ def _partial_eval_jaxpr_custom_rule( out_fwd = [idx_map.get(id(v)) for v in res_vars] which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto)), + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) @@ -1503,7 +1510,7 @@ def _partial_eval_jaxpr_custom_rule( for var, w in zip(jaxpr_staged.invars[:num_res], which): if w: rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore - if check_vma else {0: _all_newly_manual_mesh_names(mesh, auto)}) + if check_vma else {0: _all_newly_manual_mesh_names(mesh, manual_axes)}) residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) staged_in_res_names.append(rn) if check_vma: @@ -1512,7 +1519,8 @@ def _partial_eval_jaxpr_custom_rule( for var, o in zip(res_vars, out_fwd) if o is None ] else: - out_res_names_known = [{0: _all_newly_manual_mesh_names(mesh, auto)}] * sum(which) + out_res_names_known = [ + {0: _all_newly_manual_mesh_names(mesh, manual_axes)}] * sum(which) params_known, params_staged = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, out_res_names_known, staged_in_res_names, @@ -1588,14 +1596,14 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() spmd_names = axis_env.spmd_axis_names - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto) + return tuple(name for name in mesh.axis_names + if name not in spmd_names and name in manual_axes) def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName]) -> tuple[AxisName, ...]: + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: @@ -1605,7 +1613,8 @@ def _all_newly_manual_mesh_names( # TODO(mattjj): remove this mechanism when we revise mesh scopes already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names return tuple(name for name in mesh.axis_names - if name not in auto | vmap_spmd_names | already_manual_names) + if (name not in vmap_spmd_names | already_manual_names and + name in manual_axes)) # DCE @@ -1616,9 +1625,9 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn if not any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] - auto = eqn.params["auto"] + manual_axes = eqn.params["manual_axes"] check_vma = eqn.params["check_vma"] - with _extend_axis_env(mesh, auto), config._check_vma(check_vma): + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None