MegaScale discovery is run twice again #8954

Open
tengyifei opened this issue Apr 9, 2025 · 2 comments
Labels: bug (Something isn't working), xla:tpu (TPU specific issues and PRs)

Comments

@tengyifei (Collaborator)

One of the PyTorch/XLA operations used by the Llama model in https://github.com/AI-Hypercomputer/torchprime is triggering JAX MegaScale discovery again. This issue tracks identifying that operation and removing the jax_env_context workaround.
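One way to identify the triggering operation is to wrap whatever entry point performs backend discovery so that its first invocation records the Python call stack. A minimal sketch of that debugging pattern, using a hypothetical `init_backend` stand-in rather than the actual JAX or PyTorch/XLA internals:

```python
import functools
import traceback

def trace_first_call(fn):
    """Wrap fn so its first invocation records the Python call stack.

    Patching a (hypothetical) backend-init entry point with this
    decorator reveals which caller triggered initialization.
    """
    state = {"stack": None}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if state["stack"] is None:
            state["stack"] = traceback.format_stack()
        return fn(*args, **kwargs)

    wrapper.first_call_stack = state
    return wrapper

# Stand-ins for the real initializer and the op under suspicion:
@trace_first_call
def init_backend():
    return "initialized"

def run_pallas_kernel():
    # In the real bug, an op like this implicitly initializes JAX.
    return init_backend()

run_pallas_kernel()
stack = init_backend.first_call_stack["stack"]
assert stack is not None
assert any("run_pallas_kernel" in line for line in stack)
```

In practice one would monkeypatch the actual discovery entry point at process start and inspect the recorded stack to find the offending operation.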

@ysiraichi (Collaborator)

Could you either provide more context or label this issue accordingly?

@tengyifei tengyifei added the bug Something isn't working label Apr 16, 2025
@tengyifei (Collaborator, Author)

PyTorch/XLA uses JAX for Pallas kernels and similar features, so it is easy to accidentally trigger JAX backend initialization. When PyTorch/XLA has already initialized its own TPU backend and JAX then tries to initialize a second one, that second initialization hangs in multi-slice (DCN network) environments. We have seen this bug repeatedly; this is the latest incarnation.

@ysiraichi ysiraichi added the xla:tpu TPU specific issues and PRs label Apr 17, 2025
2 participants