One of the PyTorch/XLA operations used by the torchprime (https://github.com/AI-Hypercomputer/torchprime) Llama model is triggering JAX MegaScale discovery again. This issue tracks identifying that operation and removing the `jax_env_context` workaround.
Background: PyTorch/XLA uses JAX for Pallas kernels and similar features, so it is easy to accidentally trigger JAX backend initialization. When PyTorch/XLA has already initialized its own TPU backend and JAX then tries to initialize a second one, that second initialization hangs in multi-slice (DCN network) environments. We have hit this bug repeatedly; this is just the latest incarnation.
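For context, a workaround of this shape usually wraps the offending call in a context manager that pins JAX to a non-TPU backend before JAX gets a chance to initialize. The sketch below is only an illustration of that pattern: the use of `JAX_PLATFORMS` and the structure shown here are assumptions, not a copy of torchprime's actual `jax_env_context` helper, and it only has an effect if it runs before JAX initializes its backends.

```python
# Minimal sketch of an environment-guard workaround (assumed shape of
# torchprime's jax_env_context, not its real implementation).
import contextlib
import os


@contextlib.contextmanager
def jax_env_context():
    """Temporarily force JAX onto the CPU backend so that code which
    imports or lazily initializes JAX (e.g. Pallas kernel lowering)
    cannot start a second TPU backend and trigger MegaScale/DCN
    discovery. Only effective if entered before JAX initializes."""
    saved = os.environ.get("JAX_PLATFORMS")
    os.environ["JAX_PLATFORMS"] = "cpu"  # hypothetical guard value
    try:
        yield
    finally:
        # Restore whatever the caller had set, if anything.
        if saved is None:
            os.environ.pop("JAX_PLATFORMS", None)
        else:
            os.environ["JAX_PLATFORMS"] = saved
```

The goal of this issue is to find which operation needs such a guard in the first place, so the wrapper can be removed rather than maintained.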