PR #23619: Fix the EmitSort checking after enabling NVLS and user buffer #23935

Open
copybara-service[bot] wants to merge 1 commit into main

PR #23619: Fix the EmitSort checking after enabling NVLS and user buffer

Imported from GitHub PR #23619

NVIDIA reported a bug: running the Midjourney model triggers an XLA error after enabling NVLS and user buffers by setting NCCL_NVLS_ENABLE=1 and --xla_gpu_enable_nccl_user_buffers=true:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/xla/xla/service/gpu/ir_emitter_unnested.cc:1676) LayoutUtil::LayoutsInShapesEqual(keys_shape, sort->operand(i)->shape()).

This happens because, with NVLS and user buffers enabled, one of the operands of the sort operation comes from a different memory space (the user buffer). The previous LayoutsInShapesEqual check is too strict to pass, because it also requires the operands to be in the same memory space.

This PR weakens the sort layout check: operands no longer have to be in the same memory space, as long as they are all on the device.
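
To illustrate the difference between the strict and the relaxed check, here is a minimal self-contained sketch. It does not use XLA's real classes: `ToyLayout`, `StrictLayoutsEqual`, `RelaxedLayoutsEqual`, and the memory-space constants are all hypothetical names and values chosen for this example, and the actual change in this PR may be structured differently.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor layout; NOT XLA's real Layout class.
struct ToyLayout {
  std::vector<int64_t> minor_to_major;  // dimension ordering
  int64_t memory_space;                 // which memory the buffer lives in
};

// Illustrative memory-space ids, assumed for this sketch only.
constexpr int64_t kDefaultSpace = 0;
constexpr int64_t kUserBufferSpace = 1;
constexpr int64_t kHostSpace = 5;

// Strict comparison: dimension order AND memory space must both match.
// This mirrors the behavior that made the old check fail.
bool StrictLayoutsEqual(const ToyLayout& a, const ToyLayout& b) {
  return a.minor_to_major == b.minor_to_major &&
         a.memory_space == b.memory_space;
}

// Relaxed comparison: dimension order must match, but memory spaces may
// differ as long as neither operand lives in host memory.
bool RelaxedLayoutsEqual(const ToyLayout& a, const ToyLayout& b) {
  const bool both_on_device =
      a.memory_space != kHostSpace && b.memory_space != kHostSpace;
  return a.minor_to_major == b.minor_to_major && both_on_device;
}
```

With a keys operand in the default space and another operand in the user-buffer space, the strict comparison rejects the pair while the relaxed one accepts it; a pair involving a host-resident operand is still rejected by both.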
Copybara import of the project:

--
8265352 by Chenhao Jiang:

Make the sort layout check weaker: operands do not have to be in the same memory space as long as they are all on the device.

Merging this change closes #23619

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23619 from serach24:chenhao/fix_nvsl_check_failed 8265352

copybara-service[bot] force-pushed the test_738419324 branch 2 times, most recently from d24947d to f5741c4 on March 19, 2025 17:31
PiperOrigin-RevId: 738419324