Request for supporting conv1d and pooling ops with explicit sharding. #28090
Labels: enhancement (New feature or request)
copybara-service bot pushed a commit that referenced this issue on Apr 24, 2025:
This rule only works when rhs is fully replicated or rhs's mesh is empty (i.e. rhs is a numpy array or jnp.array). In this case, we just forward the sharding of lhs to the output (after making sure that the sharding evenly divides the out_shape). And since reduce_window is the exact same thing as the above case, do the same in its sharding rule. Fixes #28090
PiperOrigin-RevId: 748736039
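For illustration (not taken from the referenced commit or this issue), here is a minimal sketch of the case this rule targets: conv_general_dilated with a batch-sharded lhs and a replicated NumPy rhs. The explicit-sharding setup (jax.make_mesh with AxisType.Explicit, jax.sharding.use_mesh) follows the JAX explicit-sharding documentation and is an assumption here; exact names may vary across JAX versions.

```python
# Hedged sketch, not from the issue: conv with a sharded lhs and a
# replicated (numpy) rhs, the case whose sharding the rule above forwards.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Assumption: explicit-sharding mesh API as described in the JAX docs.
mesh = jax.make_mesh((jax.device_count(),), ("data",),
                     axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
    # lhs: (N, C, W) activations, sharded along the batch dimension N.
    lhs = jax.device_put(jnp.ones((8, 4, 128)),
                         NamedSharding(mesh, P("data", None, None)))
    # rhs: (O, I, K) filters as a plain numpy array, i.e. an empty mesh /
    # fully replicated operand.
    rhs = np.ones((16, 4, 3), dtype=np.float32)

    out = jax.lax.conv_general_dilated(
        lhs, rhs, window_strides=(1,), padding="SAME",
        dimension_numbers=("NCH", "OIH", "NCH"))

    # Under the rule described above, the output simply inherits lhs's
    # sharding, provided the sharded output dimension divides evenly.
    print(out.sharding)  # expect the batch axis sharded over "data"
```

Forwarding the lhs sharding is sound in this case because a fully replicated rhs means every device already holds all filter weights, so the computation is plain data parallelism over the sharded batch dimension.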
copybara-service bot pushed a commit that referenced this issue on Apr 25, 2025:
This rule only works when rhs is fully replicated or rhs's mesh is empty (i.e. rhs is a numpy array or jnp.array). In this case, we just forward the sharding of lhs to the output (after making sure that the sharding evenly divides the out_shape). And since reduce_window is the exact same thing as the above case (i.e. lhs sharded, rhs fully replicated), do the same in its sharding rule. Fixes #28090
PiperOrigin-RevId: 748736039
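The analogous reduce_window case can be sketched the same way (again, this is illustrative rather than taken from the issue): 1-D max pooling over a batch-sharded input, where the window only touches unsharded dimensions, so the input sharding can be forwarded unchanged. The same assumptions about the explicit-sharding API apply as in the sketch above.

```python
# Hedged sketch, not from the issue: reduce_window (max pooling) on a
# batch-sharded input; the window covers only unsharded dimensions.
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Assumption: explicit-sharding mesh API as described in the JAX docs.
mesh = jax.make_mesh((jax.device_count(),), ("data",),
                     axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
    # x: (N, C, W), sharded along the batch dimension N.
    x = jax.device_put(jnp.ones((8, 4, 128)),
                       NamedSharding(mesh, P("data", None, None)))

    pooled = jax.lax.reduce_window(
        x, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 1, 2),
        window_strides=(1, 1, 2),
        padding="VALID")

    print(pooled.shape)     # (8, 4, 64)
    print(pooled.sharding)  # expect ("data", None, None), same as the input
```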
charleshofer pushed a commit to ROCm/jax that referenced this issue on Apr 30, 2025, with the same commit message (Fixes jax-ml#28090; PiperOrigin-RevId: 751534065).
charleshofer pushed a commit to ROCm/jax that referenced this issue on May 1, 2025, with the same commit message (Fixes jax-ml#28090; PiperOrigin-RevId: 751534065).
andportnoy pushed a commit to andportnoy/jax that referenced this issue on May 2, 2025, with the same commit message (Fixes jax-ml#28090; PiperOrigin-RevId: 751534065).
Dear JAX Team,
We are working on a project that heavily utilizes 1D convolutional and pooling layers across sharded inputs. We would like to migrate to the new explicit sharding API, but found jax.lax.conv_general_dilated and jax.lax.reduce_window to be breaking for our use case. Please find below the specific details of the operations with a minimal failing example. Would it be possible to support this in future releases?
Many thanks!
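The minimal failing example mentioned above is not preserved in this copy of the issue. As a stand-in, here is a hedged reconstruction of the kind of usage described: a 1-D convolution followed by pooling over a batch-sharded input under the explicit-sharding API. The helper name conv1d_block and the shapes are made up for illustration, and the explicit-sharding calls (jax.make_mesh with AxisType.Explicit, jax.sharding.use_mesh) are assumptions that may differ across JAX versions.

```python
# Illustrative reconstruction only; this is NOT the reporter's original repro.
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Assumed explicit-sharding setup (names per the JAX explicit-sharding docs).
mesh = jax.make_mesh((jax.device_count(),), ("data",),
                     axis_types=(AxisType.Explicit,))

def conv1d_block(x, w):
    # x: (N, C, W) batch-sharded activations; w: replicated (O, I, K) filters.
    y = jax.lax.conv_general_dilated(
        x, w, window_strides=(1,), padding="SAME",
        dimension_numbers=("NCH", "OIH", "NCH"))
    # 1-D max pooling, window 2 / stride 2 over the spatial dimension.
    return jax.lax.reduce_window(
        y, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 1, 2), window_strides=(1, 1, 2),
        padding="VALID")

with jax.sharding.use_mesh(mesh):
    x = jax.device_put(jnp.ones((8, 4, 128)),
                       NamedSharding(mesh, P("data", None, None)))
    w = jnp.ones((16, 4, 3))  # created without an explicit sharding, i.e. replicated

    # Per the issue, tracing ops like these with a sharded input used to fail
    # under explicit sharding; the commits referenced above add a sharding
    # rule for the (lhs sharded, rhs fully replicated) case.
    out = jax.jit(conv1d_block)(x, w)
    print(out.shape, out.sharding)
```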