@@ -53,6 +53,8 @@ class ConvDimensionNumbers(NamedTuple):
   None,
 ]

+# TODO(yashkatariya): conv_general_dilated should take `out_sharding` argument
+# similar to `dot_general`
 def conv_general_dilated(
     lhs: Array, rhs: Array, window_strides: Sequence[int],
     padding: str | Sequence[tuple[int, int]],
@@ -415,6 +417,26 @@ def _conv_general_dilated_shape_rule(
   return tuple(np.take(out_trans, np.argsort(out_perm)))


+def _conv_general_dilated_sharding_rule(
+    lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
+    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
+    batch_group_count, **unused_kwargs):
+  # Only allow if rhs is fully replicated and lhs's feature dim is not sharded
+  if ((rhs.sharding.mesh.empty or rhs.sharding.is_fully_replicated) and
+      lhs.sharding.spec[dimension_numbers.lhs_spec[1]] is None):
+    out_shape = _conv_general_dilated_shape_rule(
+        lhs, rhs, window_strides=window_strides, padding=padding,
+        lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
+        dimension_numbers=dimension_numbers,
+        feature_group_count=feature_group_count,
+        batch_group_count=batch_group_count)
+    return lax.slicing._get_sharding_for_varying_out_shape(
+        out_shape, lhs, "conv_general_dilated")
+  # TODO(yashkatariya): In this case, just let the user specify the out_sharding
+  # via `out_sharding` argument to `conv_general_dilated`.
+  raise core.ShardingTypeError(
+      "Please file an issue at https://github.com/jax-ml/jax/issues")
+
 def _conv_general_dilated_dtype_rule(
     lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
     dimension_numbers, preferred_element_type, **unused_kwargs):
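Aside (not part of the diff): the `dimension_numbers.lhs_spec[1]` index checked by the new rule is the lhs feature dimension as canonicalized by `lax.conv_dimension_numbers`. A minimal sketch, with NCHW/OIHW shapes chosen purely for illustration:

from jax import lax

lhs_shape = (8, 3, 32, 32)   # N C H W (illustrative shapes, not from the diff)
rhs_shape = (16, 3, 3, 3)    # O I H W
dnums = lax.conv_dimension_numbers(lhs_shape, rhs_shape, ('NCHW', 'OIHW', 'NCHW'))
print(dnums.lhs_spec)     # (0, 1, 2, 3): entry 1 is the lhs feature ('C') axis
print(dnums.lhs_spec[1])  # 1 -> the rule requires lhs.sharding.spec[1] to be None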
@@ -635,6 +657,7 @@ def _conv_general_dilated_batch_rule(
 conv_general_dilated_p = lax.standard_primitive(
     _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
     'conv_general_dilated',
+    sharding_rule=_conv_general_dilated_sharding_rule,
     vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated'))

 ad.defbilinear(conv_general_dilated_p,
@@ -713,21 +736,18 @@ def _conv_general_dilated_lower(
     # TODO(https://github.com/openxla/stablehlo/issues/1268)
     raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
   if all(core.is_constant_shape(p) for p in padding):
-    return [
-        hlo.convolution(
-            mlir.aval_to_ir_type(aval_out),
-            lhs,
-            rhs,
-            dimension_numbers=dnums,
-            feature_group_count=mlir.i64_attr(feature_group_count),
-            batch_group_count=mlir.i64_attr(batch_group_count),
-            window_strides=mlir.dense_int_array(window_strides),
-            padding=mlir.dense_int_elements(padding),
-            lhs_dilation=mlir.dense_int_array(lhs_dilation),
-            rhs_dilation=mlir.dense_int_array(rhs_dilation),
-            window_reversal=window_reversal,
-            precision_config=lax.precision_attr(precision))
-    ]
+    out = hlo.convolution(
+        mlir.aval_to_ir_type(aval_out), lhs, rhs,
+        dimension_numbers=dnums,
+        feature_group_count=mlir.i64_attr(feature_group_count),
+        batch_group_count=mlir.i64_attr(batch_group_count),
+        window_strides=mlir.dense_int_array(window_strides),
+        padding=mlir.dense_int_elements(padding),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
+        window_reversal=window_reversal,
+        precision_config=lax.precision_attr(precision))
+    return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
   else:
     # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
     # spatial dimension.
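For context, a hedged user-level sketch (not taken from this change) of the partitioning pattern the new sharding rule admits: lhs sharded over its batch dimension, rhs fully replicated, and the lhs feature dimension unsharded. The mesh axis name 'x', the shapes, and the NCHW/OIHW layout are illustrative assumptions; sharding the rhs or the lhs feature axis is the case routed to the ShardingTypeError branch above.

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh; works with one or more devices as long as the device
# count divides the batch size (8 here).
mesh = Mesh(jax.devices()[:2], ('x',))

lhs = jax.device_put(jnp.ones((8, 3, 32, 32)),   # N C H W, batch sharded on 'x'
                     NamedSharding(mesh, P('x', None, None, None)))
rhs = jax.device_put(jnp.ones((16, 3, 3, 3)),    # O I H W, fully replicated
                     NamedSharding(mesh, P()))

@jax.jit
def conv(lhs, rhs):
  return lax.conv_general_dilated(
      lhs, rhs, window_strides=(1, 1), padding='SAME',
      dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

out = conv(lhs, rhs)  # output batch dimension can stay partitioned over 'x'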