Request for supporting conv1d and pooling ops with explicit sharding. #28090

Closed
vdutor opened this issue Apr 17, 2025 · 0 comments · Fixed by #28253
Labels: enhancement (New feature or request)

vdutor commented Apr 17, 2025

Dear JAX Team,

We are working on a project that relies heavily on 1D convolutional and pooling layers over sharded inputs. We would like to migrate to the new explicit sharding API, but found that jax.lax.conv_general_dilated and jax.lax.reduce_window break for our use case. The specific operations, together with a minimal failing example, are below:

import jax
jax.config.update('jax_num_cpu_devices', 8)  # Run with 8 devices.
import numpy as np
from jax.experimental.shard import reshard


mesh = jax.make_mesh((2, 4), ('x', 'y'), axis_types=(jax.sharding.AxisType.Explicit,)*2)

with jax.sharding.use_mesh(mesh):
  inputs = reshard(np.zeros((16, 128, 7)), jax.sharding.PartitionSpec('x', 'y'))
  # Conv1D across sharded y-axis:
  _ = jax.lax.conv_general_dilated(
      inputs,
      np.zeros((5, 7, 11)),
      window_strides=(1,),
      padding='SAME',
      feature_group_count=1,
      lhs_dilation=(1,),
      rhs_dilation=(1,),
      dimension_numbers=('NWC', 'WIO', 'NWC'),
  )
  
  # Max pooling along sharded y-axis.
  _ = jax.lax.reduce_window(inputs, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME')
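
For reference, the behavior we would hope for is that these ops simply forward the input sharding to the output. A sketch of how we would check that (our own illustration; it assumes jax.typeof from recent JAX releases, which reports the sharding-annotated type under explicit mode):

  # Hypothetical check, continuing the snippet above inside use_mesh(mesh).
  out = jax.lax.reduce_window(inputs, -np.inf, jax.lax.max, (1, 2, 1), (1, 2, 1), 'SAME')
  print(jax.typeof(out))  # hoped for: something like float32[16@x,64@y,7]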

Would it be possible to support this in future releases?

Many thanks!

vdutor added the enhancement (New feature or request) label Apr 17, 2025
yashk2810 self-assigned this Apr 17, 2025
copybara-service bot pushed a commit that referenced this issue Apr 25, 2025

This sharding rule for conv_general_dilated only works when rhs is fully replicated or rhs's mesh is empty (i.e. rhs is a numpy array or jnp.array). In this case, we just forward the sharding of lhs to the output (after making sure that the sharding evenly divides out_shape).

And since reduce_window is the exact same case (i.e. lhs sharded, rhs fully replicated), do the same in its sharding rule.

Fixes #28090

PiperOrigin-RevId: 748736039
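
Restated as a conceptual sketch (hypothetical names; not the actual implementation merged in #28253): when rhs carries no sharding, the rule forwards lhs's PartitionSpec to the output, and fails if a sharded output dimension no longer divides evenly over its mesh axis.

from jax.sharding import PartitionSpec as P

def forward_lhs_sharding(lhs_spec, out_shape, mesh_axis_sizes, rhs_is_replicated):
  # The rule only applies when rhs is fully replicated (or has no mesh).
  if not rhs_is_replicated:
    raise NotImplementedError('rhs must be fully replicated')
  # Check that every sharded output dimension still divides evenly.
  for dim, axis in zip(out_shape, lhs_spec):
    if axis is not None and dim % mesh_axis_sizes[axis] != 0:
      raise ValueError(f'output dim {dim} does not divide over mesh axis {axis!r}')
  return P(*lhs_spec)  # forward the lhs sharding to the output

# e.g. for the reduce_window call in the issue:
# forward_lhs_sharding(P('x', 'y'), (16, 64, 7), {'x': 2, 'y': 4}, True) -> P('x', 'y')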