Skip to content

Remove scheduling annotations on the cloned computation for collective groups #24794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

chaserileyroberts
Copy link
Contributor

@chaserileyroberts chaserileyroberts commented Apr 8, 2025

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c

Would crash with an error like

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1. 

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that with set_xla_metadata is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner Call operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a jax.jit.

@reedwm reedwm requested a review from rosiezou April 8, 2025 15:58
copybara-service bot pushed a commit that referenced this pull request Apr 8, 2025
…r collective groups

Imported from GitHub PR #24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26 by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes #24794

FUTURE_COPYBARA_INTEGRATE_REVIEW=#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26
PiperOrigin-RevId: 745224483
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 8, 2025
…r collective groups

Imported from GitHub PR openxla/xla#24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes #24794

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26a7f28b008869be12a90b7a4fdf61f219c
PiperOrigin-RevId: 745224483
copybara-service bot pushed a commit that referenced this pull request Apr 8, 2025
…r collective groups

Imported from GitHub PR #24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26 by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes #24794

FUTURE_COPYBARA_INTEGRATE_REVIEW=#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26
PiperOrigin-RevId: 745224483
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 8, 2025
…r collective groups

Imported from GitHub PR openxla/xla#24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes #24794

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26a7f28b008869be12a90b7a4fdf61f219c
PiperOrigin-RevId: 745224483
@copybara-service copybara-service bot closed this in d1e16f1 Apr 8, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 8, 2025
…r collective groups

Imported from GitHub PR openxla/xla#24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26a7f28b008869be12a90b7a4fdf61f219c by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes #24794

PiperOrigin-RevId: 745248689
alekstheod pushed a commit to ROCm/xla that referenced this pull request Apr 11, 2025
…tion for collective groups

Imported from GitHub PR openxla#24794

Previously, when users try to combine using the collective groups with the scheduling groups, you could run into issues like this:

Example code

```python
@jax.jit
def bidir_comms(a):
    b = jax.lax.ppermute(a, "i", perm_up)
    c = jax.lax.ppermute(a, "i", perm_down)
    return b, c

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P(None, 'i'), out_specs=P(None, 'i'))
def groups(a):
   # Running the collective groups under a scheduling group.
   with set_xla_metadata( _scheduling_group_id='1'):
      with set_xla_metadata(_collectives_group="", inlineable="false"):
          b, c = bidir_comms(a)
    return b + c
```
Would crash with an error like

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There is a scheduling group which exceeds the overlap limits. Annotation id: 1. It needs 2 kGpuAsyncStreamCollectives resources, but the limit is 1.
```

This is because the instructions within the async computation also included the scheduling annotations. When the LHS would look within this computation for scheduling, it would see all of the communication operations being labeled with the same group and crash from the overlap limit check. Removing this annotation when creating the async instructions fixes this issue.

This is a simple solution for now, but the real crux of the issue is that `with set_xla_metadata` is applying frontend attributes to every operation within the context, even though we would prefer it to only be on the inner `Call` operation. We should consider adding a different JAX API that applies attributes to only the call operation created from a `jax.jit`.
Copybara import of the project:

--
a266d26 by chaser <[email protected]>:

Remove scheduling annotations on the cloned computation

Merging this change closes openxla#24794

COPYBARA_INTEGRATE_REVIEW=openxla#24794 from chaserileyroberts:chase/collective_groups_scheduling a266d26
PiperOrigin-RevId: 745248689
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants