JAX code is extremely slow on GPUs #24411

Open
AakashKumarNain opened this issue Oct 20, 2024 · 37 comments
Labels: bug (Something isn't working), performance (make things lean and fast)

@AakashKumarNain commented Oct 20, 2024

Description

Last week a discussion took place on Twitter where many people noticed that the performance of JAX on GPUs is still subpar compared to PyTorch. @mattjj and I had a discussion afterwards and agreed that a repo with a minimal codebase showcasing the performance differences between the two would be great.

GPT-2 is a good model to start with, and here is a repo that contains both JAX and PyTorch code for the same model. The instructions provided in the repo are enough to download and run the code locally (on a GPU machine).

On my side, these are the results I got on an A100 40G machine:

JAX

[screenshot: jax_run]

PyTorch

[screenshot: torch_run]

Compared to Torch, JAX is extremely slow here, and I have no idea why. There is a chance of a silent bug somewhere in the JAX code that I may have overlooked. Given the number of times I have been through this code, I think a fresh set of eyes would do it better justice. Please let me know if you need any other information from my side.

System info (python version, jaxlib version, accelerator, etc.)

jax[cuda12]
jaxlib==0.4.34
equinox==0.11.8
optax @ git+https://github.com/google-deepmind/optax.git@85378ad4ce1c19dfd218c65873f8941776c3eaca

Optax is installed from git because it includes a fix for adamw that was not part of the last release.

@AakashKumarNain added the bug (Something isn't working) label Oct 20, 2024
@bchetioui (Collaborator)

Thanks for filing! I've looked into this a bit, and I just wanted to check with you @AakashKumarNain: are the setups in both files expected to be the same?

Looking at the very first lines in torch_single_gpu.py and jax_single_gpu.py, I see

B = 8  # micro batch size
T = 1024  # sequence length

in jax_single_gpu.py and

B = 16  # micro batch size
T = 1024  # sequence length

in torch_single_gpu.py.

I believe this means that the JAX model will be running twice as many grad_accum_steps (given that B is in the denominator there). In your screenshots, those seem to be the same, however.
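
(For reference, the accumulation step count works out roughly like this, assuming the variable names from the scripts; a sketch, not the repo code:)

total_batch_size = 524288                        # 2**19 tokens, as printed in the logs
B, T = 8, 1024                                   # micro batch size and sequence length
grad_accum_steps = total_batch_size // (B * T)   # 64 with B=8, 32 with B=16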

I ran the JAX model on a single A100-40 (SXM) node, and here are the numbers I got:

$ python3 jax_single_gpu.py input.txt
total desired batch size: 524288                                                                                                                                   
=> calculated gradient accumulation steps: 64                                                                                                                      
loaded 338025 tokens                                                                                                                                               
1 epoch = 41 batches                                                                                                                                               
                                                                                                                                                                                                                                                                                                                                     
Loading GPT2 model...                                                                                                                                              
Number of parameters in the model       : 124.48 M                                                                                                                                                                                                                                                                                   
Training...                                                                                                                                                        
                                                                                 
step    0 | loss: 10.890376 | dt: 27241.55ms | tok/sec: 19245.89                                                                                                   
step    1 | loss: 10.333470 | dt: 10384.95ms | tok/sec: 50485.36                                                                                                   
step    2 | loss: 9.422712 | dt: 10379.23ms | tok/sec: 50513.18
...                                                                                                    

which are different from the initial reported numbers. Is the configuration expected to be the same in the repo as in your screenshots?

(I wasn't yet able to run the PyTorch code for comparison due to dependency issues, but hope to get to that soon.)

@bchetioui (Collaborator)

One suggested improvement! By fully unrolling the scan here, the model runs significantly faster:

total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 338025 tokens
1 epoch = 41 batches

Loading GPT2 model...
Number of parameters in the model       : 124.48 M
Training...

step    0 | loss: 10.890301 | dt: 219098.39ms | tok/sec: 2392.93
step    1 | loss: 10.333476 | dt: 6657.62ms | tok/sec: 78750.01
step    2 | loss: 9.422775 | dt: 6652.12ms | tok/sec: 78815.23
step    3 | loss: 9.101684 | dt: 6637.56ms | tok/sec: 78988.09
step    4 | loss: 8.898127 | dt: 6620.79ms | tok/sec: 79188.10
...

(Unrolling scans/loops is generally a good way to let the compiler optimize more aggressively, at the expense of compile time.)
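
As a standalone illustration of the knob (a toy scan, not the model code):

import jax
import jax.numpy as jnp

def body(carry, _):
    # stand-in for one transformer block
    return carry * 2.0 + 1.0, None

xs = jnp.zeros(12)  # e.g. one scan step per layer
# unroll=True (or an integer covering all steps) replicates the body in the lowered
# program, letting XLA optimize across iterations instead of emitting a device loop.
final_carry, _ = jax.lax.scan(body, jnp.array(0.0), xs, unroll=True)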

@AakashKumarNain (Author) commented Oct 25, 2024

Thanks for looking into it @bchetioui. The benchmark is run with a batch size of 8 for both setups, as I couldn't fit a batch size of 16 with JAX (most probably due to the extra pytree that acts as a mask for adamw). I might have forgotten to update the batch info in git.
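
For context, the mask I'm referring to is roughly of this shape (a sketch with made-up parameter names, not the exact repo code):

import jax.numpy as jnp
import jax.tree_util as jtu
import optax

params = {"wte": jnp.zeros((50257, 768)), "ln_bias": jnp.zeros((768,))}  # toy params
# Boolean pytree mirroring the params: apply weight decay to matrices only.
decay_mask = jtu.tree_map(lambda p: p.ndim >= 2, params)
optim = optax.adamw(learning_rate=3e-4, weight_decay=0.1, mask=decay_mask)
opt_state = optim.init(params)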

If you look at the number of batches and the number of gradient accumulation steps printed on screen, both are the same, meaning the micro batch size is the same. For the sake of transparency, though, I will update the repo and post the numbers here once again.

Re unrolling: I am not worried about the compile time for now. Let me check that. Thanks for the pointers

@jakevdp added the performance (make things lean and fast) label Oct 25, 2024
@AakashKumarNain (Author)

Updated the repo.

Re unroll: I did it, and even though I got a nice bump in the number of tokens processed per second, the time taken to process those tokens got much worse. I would rather have paid a one-time compilation penalty. Also, the performance is still nowhere close to torch, neither in terms of tokens (163K tokens processed in torch vs 22K in JAX) nor in terms of time taken.

@bchetioui (Collaborator)

Thanks for trying and for the feedback!

the time taken to process those tokens got much worse

Can you explain what that means? Presumably, if we're processing more tokens per second, then the time has to be better---or am I missing something? If you're talking about the first step, then that is because the first step includes compilation overhead. (The first call to the function triggers compilation.)
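
(Generic sketch of how I time a jitted step while excluding the compile overhead; not the benchmark script itself:)

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return (x @ x).sum()

x = jnp.ones((1024, 1024))
step(x).block_until_ready()    # first call compiles; exclude it from the timing
t0 = time.time()
step(x).block_until_ready()    # steady-state step time
print(f"dt: {(time.time() - t0) * 1e3:.2f}ms")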

Also, the performance is still nowhere close to torch, neither in terms of tokens (163K tokens processed in torch vs 22K in JAX) nor in terms of time taken.

Can you give some details about the hardware you are running on? You said you are running on A100-40, but, as posted above, I am getting ~80K tokens/sec in JAX on my side, running on an NVIDIA A100-SXM4-40GB (per nvidia-smi). I'm confused about why the numbers we see are so different, and I'd like to reproduce yours. Could you perhaps share the output of nvidia-smi on your side?

I am facing the following error when running python3 torch_single_gpu.py input.txt:

ImportError: cannot import name 'get_cuda_stream' from 'triton.runtime.jit' (/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py)

Do you have any idea where that might be coming from? Is there perhaps a requirement missing in requirements.txt---I see that it doesn't contain any information about the required Triton version?

@AakashKumarNain (Author) commented Oct 26, 2024

Can you explain what that means?

Please take a look at the numbers below:

[screenshot: benchmark numbers]

Can you give some details about the hardware you are running on? You said you are running on A100-40, but, as posted above (#24411 (comment)), I am getting ~80K tokens/sec in JAX on my side, running on an NVIDIA A100-SXM4-40GB (per nvidia-smi). I'm confused about why the numbers we see are so different, and I'd like to reproduce yours. Could you perhaps share the output of nvidia-smi on your side?

Sure thing. Here is the output:

[screenshot: nvidia-smi output]

Is there perhaps a requirement missing in requirements.txt---I see that it doesn't contain any information about the required Triton version?

No, nothing is missing from the requirements.txt file. To double check, I just created a new env from it and ran the benchmarks again, and it worked fine. Can you create a new empty env and install everything using the requirements file? The only guess I have is that you have a triton package installed in your current env that is incompatible with the torch package.

One question for you: How are you unrolling the scan? Maybe we are getting different numbers because of that?

@bchetioui (Collaborator)

Sure thing. Here is the output:

Thank you for sharing! The chip is the same, so that's not the difference. One difference worth noting (though not sure if it matters) is that I'm using an older CUDA version than you are:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              49W / 400W |  30753MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    105579      C   python3                                   30740MiB |
+---------------------------------------------------------------------------------------+

Can you create a new empty env and install everything using the requirements file? The only guess I have is that you have a triton package installed in your current env that is incompatible with the torch package.

Thanks, seems like retrying with a new env worked after a couple of tries! (There were some issues due to incomplete installs of torch; that may also be what had happened before.) I managed to reproduce your numbers for torch, making it currently ~2x faster than the JAX implementation on my system (80k vs 160k tokens per second).

This difference is still very much unexpected of course, so I'll keep digging :)

One question for you: How are you unrolling the scan? Maybe we are getting different numbers because of that?

I just added the unroll kwarg to the call to jax.lax.scan:

(x, layer_idx), _ = jax.lax.scan(f, (x, layer_idx), dynamic_layers, unroll=1000)  # unrolling to 12 should suffice, I think

One thing I also noticed is that you're not actually calling into a FlashAttention kernel on the JAX side. The call to

attn = jax.nn.dot_product_attention(q, k, v, is_causal=True)

could be rewritten into

attn = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")

for that. There is a TODO left by @kaixih in the code to select the best implementation automatically, but that doesn't happen yet. As a result, the default dispatches to a soup of primitive ops that XLA can't optimize into a single FlashAttention kernel.

@AakashKumarNain (Author)

Thanks, seems like retrying with a new env worked after a couple of tries!

Good to know. Yeah, sometimes it happens if there is a silent package conflict or a corrupted install from the cache

I managed to reproduce your numbers for torch, making it currently ~2x faster than the JAX implementation on my system (80k vs 160k tokens per second). This difference is still very much unexpected of course, so I'll keep digging :)

Yes, I was shocked to see the performance difference between the two.

I just added the unroll kwarg to the call to jax.lax.scan:

I did the same but with a much lower number, and you are right, unrolling to 12 should be enough. Will check that.

There is a TODO left by @kaixih in the code to select the best implementation automatically, but that doesn't happen yet. As a result, the default dispatches to a soup of primitive ops that XLA can't optimize into a single FlashAttention kernel.

I read it when I opened a discussion on the attention functionality the month before last. Still missed it. Sigh!

@bchetioui (Collaborator) commented Oct 29, 2024

So, after updating my system to (roughly) match your driver version, I was able to run with both FlashAttention and loop unrolling enabled. This gets me to:

$ python3 jax_single_gpu.py input.txt 
total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 338025 tokens
1 epoch = 41 batches

Loading GPT2 model...
Number of parameters in the model       : 124.48 M
Training...

step    0 | loss: 10.890341 | dt: 180510.78ms | tok/sec: 2904.47
step    1 | loss: 10.333479 | dt: 4005.35ms | tok/sec: 130897.06
step    2 | loss: 9.422725 | dt: 4005.06ms | tok/sec: 130906.46
step    3 | loss: 9.101675 | dt: 3941.68ms | tok/sec: 133011.33
step    4 | loss: 8.898125 | dt: 3960.57ms | tok/sec: 132376.92
step    5 | loss: 8.920149 | dt: 3912.61ms | tok/sec: 133999.58
step    6 | loss: 8.580977 | dt: 3992.61ms | tok/sec: 131314.64
step    7 | loss: 8.543740 | dt: 3919.76ms | tok/sec: 133755.25
...

Using FlashAttention correctly unsurprisingly turns out to be very important here :) Now, we're closer to a reasonable time, with 130k tokens per second on JAX, and 150k on Torch. Still, we should be able to do better, I'll keep looking.

New nvidia-smi output for reference:

Tue Oct 29 10:45:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P0             44W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

@bchetioui (Collaborator)

We can gain another few tokens/s by rewriting the code to use a different path for gradient accumulation versus the actual model update:

+@eqx.filter_jit(donate='all-except-first')
+def train_step_accum(model, optim, optim_state, data, targets):
+  _, grads = compute_loss(model, data, targets)
+  _, optim_state = optim.update(
+    grads, optim_state, eqx.filter(model, eqx.is_array))
+  return optim_state
+
 def main(data_file_path):
     total_batch_size = 524288  # 2**19, ~0.5M, in number of tokens
     B = 8  # micro batch size
@@ -433,13 +440,22 @@ def main(data_file_path):
         t0 = time.time()
         for micro_step in range(grad_accum_steps):
             batch_inputs, batch_targets = train_loader.next_batch()
-            loss, model, optim_state = train_step(
+            if micro_step < grad_accum_steps - 1:
+              optim_state = train_step_accum(
+                model,
+                optim,
+                optim_state,
+                batch_inputs,
+                batch_targets,
+              )
+            else:
+              loss, model, optim_state = train_step(
                 model,
                 optim,
                 optim_state,
                 batch_inputs,
                 batch_targets,
-            )
+              )

This avoids making no-op updates to the model & rewriting all the weights to HBM at each accumulation step.

This gets us to ~140k tokens/second:

$ python3 jax_single_gpu.py input.txt 
total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 338025 tokens
1 epoch = 41 batches

Loading GPT2 model...
Number of parameters in the model       : 124.48 M
Training...

step    0 | loss: 10.890341 | dt: 243491.99ms | tok/sec: 2153.20
step    1 | loss: 10.333462 | dt: 3788.36ms | tok/sec: 138394.47
step    2 | loss: 9.422718 | dt: 3760.61ms | tok/sec: 139415.82
step    3 | loss: 9.101670 | dt: 3719.87ms | tok/sec: 140942.73
step    4 | loss: 8.898199 | dt: 3740.94ms | tok/sec: 140148.87
step    5 | loss: 8.920159 | dt: 3683.78ms | tok/sec: 142323.37
step    6 | loss: 8.580997 | dt: 3713.13ms | tok/sec: 141198.42
step    7 | loss: 8.543747 | dt: 3693.19ms | tok/sec: 141960.73
...

@AakashKumarNain (Author)

Using FlashAttention correctly unsurprisingly turns out to be very important here :) Now, we're closer to a reasonable time, with 130k tokens per second on JAX, and 150k on Torch. Still, we should be able to do better, I'll keep looking.

Thanks. I am able to reproduce this number on my side. One question though: right now we are using unroll=1000, but finding the ideal number for unrolling seems like a tedious task. Is there any better way to do it? Maybe through HLO graphs?

This avoids making no-op updates to the model & rewriting all the weights to HBM at each accumulation step.

This is a nitty-gritty detail. We should add documentation for this in optax.

@bchetioui (Collaborator)

Thanks. I am able to reproduce this number on my side. One question though: right now we are using unroll=1000, but finding the ideal number for unrolling seems like a tedious task. Is there any better way to do it? Maybe through HLO graphs?

Awesome that you could reproduce it! I hacked around, but the goal is to unroll all the layers here---setting unroll=True should work. (Side note: there seems to be an off-by-one in the implementation of the scan lowering such that setting unroll=True lowers slightly less efficiently. I'll make a fix, but a temporary workaround is to set unroll=self.n_layer + 1).

This is a nitty-gritty detail. We should add documentation for this in optax.

Agreed! Where do you think that would belong exactly? (This is my first time using optax :))

Another (marginal) improvement is to pre-process the data in the data loader to avoid doing a copy at each loop step (the call to jnp.reshape in train_loader.next_batch() produces two copies). This is not strictly a win in terms of time, because some of that compute now happens before the training loop. Nevertheless, it gives a slightly more faithful comparison with Torch, since in the Torch version we get a view of the data (as opposed to a copy).
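
Roughly the idea, with hypothetical names (not the repo's actual loader): reshape the token buffer once up front so that next_batch() only indexes into it.

import numpy as np

B, T = 8, 1024
tokens = np.arange(100 * B * T + 1, dtype=np.int32)   # stand-in for the loaded tokens
n = (len(tokens) - 1) // (B * T)
inputs = tokens[:n * B * T].reshape(n, B, T)           # reshaped once, before training
targets = tokens[1:n * B * T + 1].reshape(n, B, T)

def next_batch(step):
    i = step % n
    return inputs[i], targets[i]                       # plain indexing, no per-step reshape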

$ python3 jax_single_gpu.py input.txt 
total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 335872 tokens
1 epoch = 41 batches

Loading GPT2 model...
Number of parameters in the model       : 124.48 M
Training...

step    0 | loss: 10.890341 | dt: 241500.91ms | tok/sec: 2170.96
step    1 | loss: 10.333393 | dt: 3708.82ms | tok/sec: 141362.62
step    2 | loss: 9.422662 | dt: 3680.90ms | tok/sec: 142434.74
step    3 | loss: 9.101698 | dt: 3638.58ms | tok/sec: 144091.44
step    4 | loss: 8.898143 | dt: 3679.82ms | tok/sec: 142476.38
step    5 | loss: 8.920143 | dt: 3629.01ms | tok/sec: 144471.16
step    6 | loss: 8.580982 | dt: 3639.25ms | tok/sec: 144064.72
step    7 | loss: 8.543739 | dt: 3639.98ms | tok/sec: 144035.7
...

@AakashKumarNain (Author)

but the goal is to unroll all the layers here---setting unroll=True should work. (Side note: there seems to be an off-by-one in the implementation of the scan lowering such that setting unroll=True lowers slightly less efficiently. I'll make a fix, but a temporary workaround is to set unroll=self.n_layer + 1).

Yeah, when I used unroll=True, I noticed that difference. Thanks for the explanation.

Agreed! Where do you think that would belong exactly? (This is my first time using optax :))

This should go right here. Let me know if you need more info or any help on this.

Another (marginal) improvement is to pre-process the data in the data loader to avoid doing a copy at each loop step (the call to jnp.reshape in train_loader.next_batch() produces two copies). This is not strictly a win in terms of time, because some of that compute now happens before the training loop. Nevertheless, it gives a slightly more faithful comparison with Torch, since in the Torch version we get a view of the data (as opposed to a copy).

Yeah we can do that. I just left it raw for now

@bchetioui (Collaborator)

Yeah, when I used unroll=True, I noticed that difference. Thanks for the explanation.

This commit should fix the issue at HEAD!

This should go right here. Let me know if you need more info or any help on this.

Actually, if you have a good idea of what useful documentation for you would have looked like, I'd love to see it! (Maybe you even want to send a PR?! :))

@AakashKumarNain (Author)

Actually, if you have a good idea of what useful documentation for you would have looked like, I'd love to see it! (Maybe you even want to send a PR?! :))

Cool. Leave that to me. I will make that PR both in optax and Equinox

@bchetioui (Collaborator)

I didn't have a lot of bandwidth to go further this week, but I recalled that there are a few improvements in XLA that are not yet in the last release---so I wanted to share at least that. Here are some numbers with the latest nightlies (jax-0.4.36.dev20241031, jax-cuda12-pjrt-0.4.36.dev20241101, jax-cuda12-plugin-0.4.36.dev20241101, jaxlib-0.4.36.dev20241101).

$ python3 jax_single_gpu.py input.txt 
total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 335872 tokens
1 epoch = 41 batches

Loading GPT2 model...
Number of parameters in the model       : 124.48 M
Training...

step    0 | loss: 10.890312 | dt: 245581.22ms | tok/sec: 2134.89
step    1 | loss: 10.333449 | dt: 3611.62ms | tok/sec: 145167.12
step    2 | loss: 9.422793 | dt: 3563.81ms | tok/sec: 147114.48
step    3 | loss: 9.101702 | dt: 3515.64ms | tok/sec: 149130.17
step    4 | loss: 8.898161 | dt: 3545.84ms | tok/sec: 147860.06
step    5 | loss: 8.920076 | dt: 3497.69ms | tok/sec: 149895.64
step    6 | loss: 8.580976 | dt: 3499.92ms | tok/sec: 149800.12
step    7 | loss: 8.543713 | dt: 3501.90ms | tok/sec: 149715.49
...

So we're at roughly ~150k tok/sec for now.

@AakashKumarNain (Author)

I didn't have a lot of bandwidth to go further this week

No worries @bchetioui. We are already close to the optimal number. I am hoping this exercise turns out to be really valuable for both of us in terms of performance benchmarks.

@AakashKumarNain (Author)

@bchetioui were you able to find any more bottlenecks?

@bchetioui (Collaborator) commented Nov 8, 2024

Hi @AakashKumarNain; I have a few guesses about where to look, but haven't looked into the profiles deeply just yet. At this point, it seems like we need to fix XLA performance. That said, we can still improve the model at the JAX level directly.

We know that LayerNorm is currently pretty slow in XLA (though there is ongoing work to fix this). One easy way to check whether speeding that up makes a difference is to plug in the Pallas layer_norm kernel:

+from jax.experimental.pallas.ops.gpu import layer_norm as pallas_layer_norm
 
 ######################## Equinox model utils ################################
 
 
+def layer_norm_wrapper(self, x, *args, **kwargs):
+  x = jnp.reshape(x, (1, *x.shape))
+  res, *_ = pallas_layer_norm.layer_norm(x, self.weight, self.bias)
+  return jnp.reshape(res, x.shape[1:])
+
+eqx.nn.LayerNorm.__call__ = layer_norm_wrapper
+
+
 def is_layer(x):
     """Check if the current pytree is an instance of any Equinox layer."""
     return isinstance(x, (eqx.nn.Linear, eqx.nn.Embedding, eqx.nn.LayerNorm))
@@ -214,9 +223,9 @@ class TransformerBlock(eqx.Module):
         self.mlp = MLP(config, key=key2, dtype=dtype)
 
     def __call__(self, x, mask=None):
-        x = eqx.filter_vmap(self.norm_1)(x)
+        x = self.norm_1(x)
         x = x + self.attn(x, mask=mask)
-        x = eqx.filter_vmap(self.norm_2)(x)
+        x = self.norm_2(x)
         x = x + self.mlp(x)
         return x
 
@@ -298,7 +307,7 @@ class GPT(eqx.Module):
                                          unroll=self.n_layer + 1)
 
         # 4. Final pre-layer norm
-        x = eqx.filter_vmap(self.norm)(x).astype(jnp.bfloat16)
+        x = self.norm(x).astype(jnp.bfloat16)
 
         # 5. Classification head
         logits = eqx.filter_vmap(lm_head)(x)

(Note, the plugging in is pretty hacky, the Pallas op may deserve some love :-))

The hypothesis is validated, since plugging that in gets us to roughly 156k tokens/second:

$ python3 jax_single_gpu.py input.txt                                                                                                                                                                                                                                                                                                                                                                                                      
total desired batch size: 524288                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
=> calculated gradient accumulation steps: 64                                                                                                                                                                                                                                                                                                                                                                                                                                                            
loaded 335872 tokens                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
1 epoch = 41 batches                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
Loading GPT2 model...                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
Number of parameters in the model       : 124.48 M                                                                                                                                                                                                                                                                                                                                                                                                                                                       
Training...                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
step    0 | loss: 10.890168 | dt: 233523.51ms | tok/sec: 2245.12                                                                                                                                                                                                                                                                                                                                                                                                                                         
step    1 | loss: 10.333891 | dt: 3537.08ms | tok/sec: 148226.20                                                                                                                                                                                                                                                                                                                                                                                                                                         
step    2 | loss: 9.422873 | dt: 3468.28ms | tok/sec: 151166.62                                                                                                                                                                                                                                                                                                                                                                                                                                          
step    3 | loss: 9.101831 | dt: 3384.95ms | tok/sec: 154887.89                                                                                                                                                                                                                                                                                                                                                                                                                                          
step    4 | loss: 8.898197 | dt: 3390.02ms | tok/sec: 154656.18                                                                                                                                                                                                                                                                                                                                                                                                                                          
step    5 | loss: 8.920005 | dt: 3361.72ms | tok/sec: 155958.28                                                                                                                                                                                                                                                                                                                                                                                                                                          
step    6 | loss: 8.580868 | dt: 3357.83ms | tok/sec: 156138.95                                                                                                                                                                                                                                                                                                                                                                                                                                          
step    7 | loss: 8.543556 | dt: 3400.50ms | tok/sec: 154179.69 
...

There were previous reports of Adam slowness; I'd like to poke at this and see whether there is any low-hanging fruit there, but I have yet to find cycles to do this.

@AakashKumarNain (Author)

We know that LayerNorm is currently pretty slow in XLA

Can you point me to the issue/discussion/documentation where this was found and discussed? I would like to know the reason behind it, because it may affect other ops as well.

but I have yet to find cycles to do this.

No worries. Thanks for all the help and support you have provided. This is already looking good. In an ideal world, JAX should be much faster than torch. On TPUs that is true, but not on GPUs. The problem is that GPUs are far more widespread than TPUs, and if JAX is slow on GPU, no matter how hard people (like you and me) try to convince others that JAX is better, it will not be enough.

@bchetioui (Collaborator) commented Nov 11, 2024

Can you point me to the issue/discussion/documentation where this was found and discussed? I would like to know the reason behind it, because it may affect other ops as well.

Unfortunately, I'm not aware of any public-facing bug---only internal ones. What it boils down to, however, is that normalizations (any of them, including e.g. Softmax) end up lowering to several kernels by default, as opposed to a single one. We (and I personally) have done significant work towards fixing this, but it's not yet on by default, and doesn't actually match everything that it should.
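
If you want to eyeball this yourself, one rough way (a sketch, not an official diagnostic) is to dump the compiled HLO for a hand-written layer norm and look at how many fusions it contains:

import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps)

x = jnp.ones((8, 1024, 768), dtype=jnp.bfloat16)
# Each fusion in the printed text corresponds (roughly) to a separate kernel launch.
print(jax.jit(layer_norm).lower(x).compile().as_text())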

No worries. Thanks for all the help and support you have provided. This is already looking good. In an ideal world, JAX should be much faster than torch. On TPUs that is true, but not on GPUs. The problem is that GPUs are far more widespread than TPUs, and if JAX is slow on GPU, no matter how hard people (like you and me) try to convince others that JAX is better, it will not be enough.

Happy to help! We're always working on making our stack better, and reports such as yours are very helpful to figure out what the sharp bits are for our external users, and to find out things we can do better. So thank you for filing this!

@AakashKumarNain (Author)

Unfortunately, I'm not aware of any public-facing bug---only internal ones. What it boils down to, however, is that normalizations (any of them, including e.g. Softmax) end up lowering to several kernels by default, as opposed to a single one. We (and I personally) have done significant work towards fixing this, but it's not yet on by default, and doesn't actually match everything that it should.

Oh okay. Any approximate timeline for the fixes that are already done but not in the dev branch?

Happy to help! We're always working on making our stack better, and reports such as yours are very helpful to figure out what the sharp bits are for our external users, and to find out things we can do better. So thank you for filing this!

No worries. Always happy to help make things better 🍻

@bchetioui (Collaborator)

Oh okay. Any approximate timeline for the fixes that are already done but not in the dev branch?

The changes that are done are all in the dev branch! And in fact, you can trigger the experimental rewriter using the --xla_gpu_experimental_enable_triton_softmax_priority_fusion=true flag (to be set in the XLA_FLAGS environment variable); but this doesn't work well for this workload just yet, unfortunately, which is why I haven't advertised it so far. (Note, that flag is mostly for development and may get deleted without warning---though that'll probably mean that we enabled the feature by default!)
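
For completeness, the flag has to be in the environment before JAX initializes its backends, e.g. (a sketch):

import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_experimental_enable_triton_softmax_priority_fusion=true"
)

import jax  # imported (and first compiled) only after the flag is set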

@AakashKumarNain (Author)

Oh, I have tried those, and it didn't work at the time. Here is the corresponding issue: openxla/xla#17103

@AakashKumarNain (Author)

@bchetioui did you find anything else that can speed up the runs?

@bchetioui (Collaborator)

Hi @AakashKumarNain; I didn't yet investigate further for this particular model.

In general, we have a lot of work in the pipeline that should have a generally good effect on such models, but I haven't checked how they'll apply specifically to this benchmark.

@AakashKumarNain (Author)

In general, we have a lot of work in the pipeline that should have a generally good effect on such models, but I haven't checked how they'll apply specifically to this benchmark.

No worries. I can do that. Please let me know once these improvements are in the nightly or a stable version. I can restart the benchmarking

@bchetioui (Collaborator)

No worries. I can do that. Please let me know once these improvements are in the nightly or a stable version. I can restart the benchmarking.

I will keep you updated. The next big thing will hopefully be a fix for loop unrolling, which should let us reclaim compile time and get the same out-of-the-box performance as with an unrolled scan, without actually unrolling.

@carlosgmartin (Contributor)

@bchetioui Is there a specific place where that's being discussed or tracked? I created the following related posts:

@bchetioui (Collaborator) commented Feb 11, 2025

Hi @carlosgmartin, I do not think we have public bugs.

When it comes to handling scan better, PRs have been steadily reviewed and landed (look for PRs tagged with [ds-fusion] on the XLA repository). We should not be too far off.

Regarding openxla/stablehlo#2664, it may in fact not be necessary to add a new op to avoid the round trip. If the whole loop body can fit in a single kernel, and the number of iterations is known (which we already can annotate in the compiler), we could simply lower to a loop on device.

Do you have a concrete (real) motivating example where the roundtrip to host is problematic?

@carlosgmartin (Contributor) commented Feb 11, 2025

@bchetioui

If the whole loop body can fit in a single kernel, and the number of iterations is known (which we already can annotate in the compiler), we could simply lower to a loop on device.

That sounds like what I had in mind. Keeping all computation on the device without having to return control to the host on each iteration, while avoiding the long compilation time from naive unrolling.

Could you clarify what you mean exactly by "roundtrip to host"?

@bchetioui (Collaborator)

Could you clarify what you mean exactly by "roundtrip to host"?

Returning control to the host on each iteration :)

@carlosgmartin (Contributor)

My understanding is that returning control to the host on each iteration (GPU ↔ CPU) is responsible for most of the slowdown that occurs at runtime when changing from unroll=True to unroll=1. Is that not accurate? Per this comment:

I believe that the scan implementation on GPU requires host synchronization at every step, which adds additional overhead that scales as the number of steps.

In my own work, unroll=True is significantly faster than unroll=1, with the disadvantage of a potentially very long compilation time. That's why I proposed the approach in the aforementioned links.

@bchetioui (Collaborator)

My understanding is that returning control to the host on each iteration (GPU ↔ CPU) is responsible for most of the slowdown that occurs at runtime when changing from unroll=True to unroll=1. Is that not accurate? Per #10794 (comment):

There is some overhead for sure, but I haven't quantified how much. The bulk of the overhead today comes from writing outputs to intermediate buffers before inserting them into the loop outputs, which is what the [ds-fusion] changes I mentioned previously are working to fix. They essentially will allow writing directly into the loop outputs.

When you unroll, that issue doesn't occur :)

@AakashKumarNain (Author)

No worries. I can do that. Please let me know once these improvements are in the nightly or a stable version. I can restart the benchmarking.

I will keep you updated. The next big thing will hopefully be a fix for loop unrolling, which should let us reclaim compile time and get the same out-of-the-box performance as with an unrolled scan, without actually unrolling.

Hi @bchetioui, are the changes merged (including the experimental pallas kernel for norm layers)? If so, I can start the benchmarking again.

@bchetioui (Collaborator)

Hi @AakashKumarNain,

The PR to enable the fix for loop unrolling is here, but it is currently gated on fixing some broken tests. Hopefully it can be made available soon!

including the experimental pallas kernel for norm layers

I think you're asking about the XLA lowering for norm layers, which already landed a while back---but which, as I recall, may need a little more massaging to apply to your specific workload. (In any case, this should already be reflected in the performance of a default run.)

@AakashKumarNain (Author)

Thanks @bchetioui, I will wait a little longer before doing this again then.
