-
Notifications
You must be signed in to change notification settings - Fork 568
Remove scheduling annotations on the cloned computation for collective groups #24794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
chaserileyroberts
wants to merge
1
commit into
openxla:main
from
chaserileyroberts:chase/collective_groups_scheduling
Closed
Remove scheduling annotations on the cloned computation for collective groups #24794
chaserileyroberts
wants to merge
1
commit into
openxla:main
from
chaserileyroberts:chase/collective_groups_scheduling
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
rosiezou
approved these changes
Apr 8, 2025
copybara-service bot
pushed a commit
that referenced
this pull request
Apr 8, 2025
…r collective groups Imported from GitHub PR #24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26 by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes #24794 FUTURE_COPYBARA_INTEGRATE_REVIEW=#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26 PiperOrigin-RevId: 745224483
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Apr 8, 2025
…r collective groups Imported from GitHub PR openxla/xla#24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes #24794 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26a7f28b008869be12a90b7a4fdf61f219c PiperOrigin-RevId: 745224483
rosiezou
approved these changes
Apr 8, 2025
rosiezou
approved these changes
Apr 8, 2025
copybara-service bot
pushed a commit
that referenced
this pull request
Apr 8, 2025
…r collective groups Imported from GitHub PR #24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26 by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes #24794 FUTURE_COPYBARA_INTEGRATE_REVIEW=#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26 PiperOrigin-RevId: 745224483
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Apr 8, 2025
…r collective groups Imported from GitHub PR openxla/xla#24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes #24794 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26a7f28b008869be12a90b7a4fdf61f219c PiperOrigin-RevId: 745224483
copybara-service bot
pushed a commit
to tensorflow/tensorflow
that referenced
this pull request
Apr 8, 2025
…r collective groups Imported from GitHub PR openxla/xla#24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes #24794 PiperOrigin-RevId: 745248689
alekstheod
pushed a commit
to ROCm/xla
that referenced
this pull request
Apr 11, 2025
…tion for collective groups Imported from GitHub PR openxla#24794 Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this: Example code ```python @jax.jit def bidir_comms(a): b = jax.lax.ppermute(a, "i", perm_up) c = jax.lax.ppermute(a, "i", perm_down) return b, c @jax.jit @partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i')) def groups(a): # Running the collective groups under a scheduling group. with set_xla_metadata( _scheduling_group_id='1'): with set_xla_metadata(_collectives_group="", inlineable="false"): b, c = bidir_comms(a) return b + c ``` Would crash with an error like ``` jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. ``` This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue. This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`. Copybara import of the project: -- a266d26 by chaser <[email protected]>: Remove scheduling annotations on the cloned computation Merging this change closes openxla#24794 COPYBARA_INTEGRATE_REVIEW=openxla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26 PiperOrigin-RevId: 745248689
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:
Example code
Would crash with an error like
This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.
This is a simple solution for now, but the real crux of the issue is that
with set_xla_metadata
is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the innerCall
operation. We should consider adding a different JAX API that applies attributes to only the call operation created from ajax.jit
.