
Commit 38483f7

yashk2810 authored and Google-ML-Automation committed
Add conv_general_dilated sharding rule
This rule only works when rhs is fully replicated or rhs's mesh is empty (i.e. rhs is a numpy array or a jnp.array). In that case we simply forward the sharding of lhs to the output, after checking that the sharding evenly divides the output shape.

And since reduce_window is effectively the same situation as the case above (i.e. lhs sharded, rhs fully replicated), do the same in its sharding rule.

Fixes #28090

PiperOrigin-RevId: 748736039
1 parent 0bf3e9b commit 38483f7
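
For context, here is a minimal sketch (not part of this change) of the behavior the new rules enable. It assumes a recent JAX with explicit-mode sharding (jax.make_mesh with axis_types, jax.sharding.use_mesh) and at least two devices; the shapes, mesh size, and axis names are illustrative only.

import jax
import numpy as np
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Hypothetical 2-device mesh whose single axis is Explicit, so shardings are
# carried on avals and propagated by each primitive's sharding rule.
mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

kernel = np.zeros((5, 4, 16))  # WIO kernel; a plain numpy constant, so its mesh is empty

@jax.jit
def f(x):
  # kernel has an empty mesh and x's feature ('C') axis is unsharded, so the
  # new conv rule forwards x's sharding to the output.
  out = jax.lax.conv_general_dilated(
      x, kernel, window_strides=(1,), padding='SAME',
      dimension_numbers=('NWC', 'WIO', 'NWC'))
  # reduce_window is the same situation, so max pooling over the unsharded W
  # axis also forwards the sharding.
  return jax.lax.reduce_window(
      out, -np.inf, jax.lax.max, (1, 2, 1), (1, 2, 1), 'SAME')

with jax.sharding.use_mesh(mesh):
  x = jax.device_put(np.zeros((8, 128, 4)),  # NWC input, batch sharded on 'x'
                     NamedSharding(mesh, P('x', None, None)))
  print(f(x).sharding)  # expected: NamedSharding(mesh, P('x', None, None))

If the kernel itself were sharded on a non-empty mesh, or the feature ('C') axis of x were sharded, the rule raises core.ShardingTypeError instead, which is the error case exercised in the added test.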

File tree

3 files changed: +106 -31 lines changed


jax/_src/lax/convolution.py

+35 -15
@@ -53,6 +53,8 @@ class ConvDimensionNumbers(NamedTuple):
     None,
 ]
 
+# TODO(yashkatariya): conv_general_dilated should take `out_sharding` argument
+# similar to `dot_general`
 def conv_general_dilated(
     lhs: Array, rhs: Array, window_strides: Sequence[int],
     padding: str | Sequence[tuple[int, int]],
@@ -415,6 +417,26 @@ def _conv_general_dilated_shape_rule(
   return tuple(np.take(out_trans, np.argsort(out_perm)))
 
 
+def _conv_general_dilated_sharding_rule(
+    lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
+    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
+    batch_group_count, **unused_kwargs):
+  # Only allow if rhs is fully replicated and lhs's feature dim is not sharded
+  if ((rhs.sharding.mesh.empty or rhs.sharding.is_fully_replicated) and
+      lhs.sharding.spec[dimension_numbers.lhs_spec[1]] is None):
+    out_shape = _conv_general_dilated_shape_rule(
+        lhs, rhs, window_strides=window_strides, padding=padding,
+        lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
+        dimension_numbers=dimension_numbers,
+        feature_group_count=feature_group_count,
+        batch_group_count=batch_group_count)
+    return lax.slicing._get_sharding_for_varying_out_shape(
+        out_shape, lhs, "conv_general_dilated")
+  # TODO(yashkatariya): In this case, just let the user specify the out_sharding
+  # via `out_sharding` argument to `conv_general_dilated`.
+  raise core.ShardingTypeError(
+      "Please file an issue at https://github.com/jax-ml/jax/issues")
+
 def _conv_general_dilated_dtype_rule(
     lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
     dimension_numbers, preferred_element_type, **unused_kwargs):
@@ -635,6 +657,7 @@ def _conv_general_dilated_batch_rule(
 conv_general_dilated_p = lax.standard_primitive(
     _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
     'conv_general_dilated',
+    sharding_rule=_conv_general_dilated_sharding_rule,
     vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated'))
 
 ad.defbilinear(conv_general_dilated_p,
@@ -713,21 +736,18 @@ def _conv_general_dilated_lower(
     # TODO(https://github.com/openxla/stablehlo/issues/1268)
     raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
   if all(core.is_constant_shape(p) for p in padding):
-    return [
-        hlo.convolution(
-            mlir.aval_to_ir_type(aval_out),
-            lhs,
-            rhs,
-            dimension_numbers=dnums,
-            feature_group_count=mlir.i64_attr(feature_group_count),
-            batch_group_count=mlir.i64_attr(batch_group_count),
-            window_strides=mlir.dense_int_array(window_strides),
-            padding=mlir.dense_int_elements(padding),
-            lhs_dilation=mlir.dense_int_array(lhs_dilation),
-            rhs_dilation=mlir.dense_int_array(rhs_dilation),
-            window_reversal=window_reversal,
-            precision_config=lax.precision_attr(precision))
-    ]
+    out = hlo.convolution(
+        mlir.aval_to_ir_type(aval_out), lhs, rhs,
+        dimension_numbers=dnums,
+        feature_group_count=mlir.i64_attr(feature_group_count),
+        batch_group_count=mlir.i64_attr(batch_group_count),
+        window_strides=mlir.dense_int_array(window_strides),
+        padding=mlir.dense_int_elements(padding),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
+        window_reversal=window_reversal,
+        precision_config=lax.precision_attr(precision))
+    return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
   else:
     # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
     # spatial dimension.
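
An aside, not part of the commit: the feature-axis check in the new sharding rule indexes dimension_numbers.lhs_spec[1], which is the position of the input-feature ('C') dimension in the lhs layout. A quick illustration using the public jax.lax.conv_dimension_numbers helper (shapes here are made up):

from jax import lax

# For an 'NWC' lhs layout, lhs_spec is (batch, feature, spatial) = (0, 2, 1),
# so lhs_spec[1] == 2 selects the 'C' axis that the rule requires to be unsharded.
dnums = lax.conv_dimension_numbers((8, 128, 4), (5, 4, 16), ('NWC', 'WIO', 'NWC'))
print(dnums.lhs_spec)  # (0, 2, 1)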

jax/_src/lax/windowed_reductions.py

+19 -16
@@ -520,21 +520,11 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
 
 def reduce_window_sharding_rule(operand, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
-  if base_dilation is None:
-    base_dilation = [1] * operand.ndim
-  if window_dilation is None:
-    window_dilation = [1] * operand.ndim
-
-  for spec, wdim, ws, pd, bd, wdil in zip(
-      operand.sharding.spec, window_dimensions, window_strides, padding,
-      base_dilation, window_dilation):
-    if spec is None:
-      continue
-    if not (wdim == 1 and ws == 1 and pd == (0, 0) and bd == 1 and wdil == 1):
-      raise core.ShardingTypeError(
-          "Only trivial windowing is supported along non-replicated"
-          f" dimensions. Got {operand.sharding.spec=}")
-  return operand.sharding
+  out_shape = reduce_window_shape_tuple(
+      operand.shape, window_dimensions, window_strides, padding, base_dilation,
+      window_dilation)
+  return lax.slicing._get_sharding_for_varying_out_shape(
+      out_shape, operand, 'reduce_window')
 
 reduce_window_sum_p = lax.standard_primitive(
     _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum',
@@ -680,8 +670,14 @@ def _select_and_scatter_shape_rule(
     raise TypeError(msg.format(window_strides, window_dimensions))
   return operand.shape
 
+def _select_and_scatter_sharding_rule(
+    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
+    scatter_consts, window_dimensions, window_strides, padding):
+  return operand.sharding
+
 select_and_scatter_p = lax.standard_primitive(
     _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter',
+    sharding_rule=_select_and_scatter_sharding_rule,
     vma_rule=partial(core.standard_vma_rule, 'select_and_scatter'))
 
 def _select_and_scatter_lower(
@@ -722,7 +718,8 @@ def _select_and_scatter_lower(
         *scatter.arguments,
         dim_var_values=ctx.dim_var_values)
     hlo.return_(mlir.flatten_ir_values(out_nodes))
-  return op.results
+  return [mlir.lower_with_sharding_in_types(ctx, r, aval)
+          for r, aval in zip(op.results, ctx.avals_out)]
 
 mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
 
@@ -731,6 +728,11 @@ def _select_and_scatter_add_shape_rule(
     padding):
   return operand.shape
 
+def _select_and_scatter_add_sharding_rule(
+    source, operand, *, select_prim, window_dimensions, window_strides,
+    padding):
+  return operand.sharding
+
 def _select_and_scatter_add_jvp(
     primals, tangents, *, select_prim, window_dimensions, window_strides,
     padding):
@@ -779,6 +781,7 @@ def _select_and_scatter_add_batch_rule(
 select_and_scatter_add_p = lax.standard_primitive(
     _select_and_scatter_add_shape_rule, lax._input_dtype,
     'select_and_scatter_add',
+    sharding_rule=_select_and_scatter_add_sharding_rule,
     vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add'))
 
 ad.primitive_transposes[select_and_scatter_add_p] = \

tests/pjit_test.py

+52
@@ -7517,6 +7517,58 @@ def f2(x, i, j):
       return x.at[i].set(x_j)
     f2(x,i,j)  # doesn't crash
 
+  @jtu.with_explicit_mesh((4, 2), ('x', 'y'))
+  def test_conv_general_dilated(self, mesh):
+    arr = jax.device_put(np.zeros((16, 128, 8)), P('x', 'y'))
+
+    @jax.jit
+    def f(x):
+      # Conv1D across sharded y-axis:
+      out = jax.lax.conv_general_dilated(
+          x, np.zeros((5, 8, 10)),
+          window_strides=(1,), padding='SAME', feature_group_count=1,
+          lhs_dilation=(1,), rhs_dilation=(1,),
+          dimension_numbers=('NWC', 'WIO', 'NWC'))
+      self.assertEqual(out.aval.sharding.spec, P('x', 'y', None))
+      # Max pooling along sharded y-axis.
+      out2 = jax.lax.reduce_window(
+          out, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME')
+      self.assertEqual(out2.aval.sharding.spec, P('x', 'y', None))
+      return out2
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y', None)))
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
+
+    jax.jit(jax.grad(lambda x: f(x).sum()))(arr)  # doesn't crash
+
+    with self.assertRaises(core.ShardingTypeError):
+      arr2 = jax.device_put(np.zeros((16, 128, 8)), P('x', None, 'y'))
+      f(arr2)
+
+  @parameterized.named_parameters(
+      ('spec1', P('x', 'y', None)),
+      ('spec2', P('x', None, 'y')),
+      ('spec3', P(None, 'x', 'y')),
+      ('spec4', P(('x', 'y'), None, None))
+  )
+  @jtu.with_explicit_mesh((4, 2), ('x', 'y'))
+  def test_reduce_window(self, spec, mesh):
+    arr = jax.device_put(np.zeros((16, 128, 8)), spec)
+
+    @jax.jit
+    def f(x):
+      out = jax.lax.reduce_window(
+          x, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME')
+      self.assertEqual(out.aval.sharding.spec, spec)
+      return out
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, spec))
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
+
+    jax.jit(jax.grad(lambda x: f(x).sum()))(arr)  # doesn't crash
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
