JAX code is extremely slow on GPUs #24411
Thanks for filing! I've looked into this a bit, and I just wanted to check with you @AakashKumarNain: are the setups in both files expected to be the same? Looking at the very first lines in each file, the configured batch sizes appear to differ, and I believe this means that the JAX model will be running twice as many gradient accumulation steps.

I ran the JAX model on a single A100-40 (SXM) node, and here are the numbers I got:

$ python3 jax_single_gpu.py input.txt
total desired batch size: 524288
=> calculated gradient accumulation steps: 64
loaded 338025 tokens
1 epoch = 41 batches
Loading GPT2 model...
Number of parameters in the model : 124.48 M
Training...
step 0 | loss: 10.890376 | dt: 27241.55ms | tok/sec: 19245.89
step 1 | loss: 10.333470 | dt: 10384.95ms | tok/sec: 50485.36
step 2 | loss: 9.422712 | dt: 10379.23ms | tok/sec: 50513.18
...

These numbers are different from the ones initially reported. Is the configuration in the repo expected to be the same as in your screenshots? (I wasn't yet able to run the PyTorch code for comparison due to dependency issues, but hope to get to that soon.) |
One suggested improvement! By fully unrolling the scan over the transformer layers (via the unroll argument to jax.lax.scan), you may see better runtime performance. (Unrolling will increase compile time, though.) |
Thanks for looking into it @bchetioui. The benchmark is run for both setups with a batch size of 8, as I couldn't fit a batch size of 16 with JAX (most probably due to the extra pytree that acts as a mask for adamw). I might have forgotten to update the batch info in git. If you look at the number of batches and the number of gradient accumulation steps printed on the screen, both are the same, meaning the micro batch size is the same. For the sake of transparency, though, I will update the repo and post the numbers here once again. Re unrolling: I am not worried about the compile time for now. Let me check that. Thanks for the pointers. |
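For context on the adamw mask mentioned above, here is a minimal sketch of the usual optax masking pattern. It is illustrative only (toy parameters, not the repo's model); passing a callable as mask is one way to avoid materializing an extra boolean pytree up front.

```python
import jax
import jax.numpy as jnp
import optax

def decay_mask(params):
    # Apply weight decay only to >=2D weight arrays; skip biases / norm scales.
    return jax.tree_util.tree_map(lambda p: p.ndim >= 2, params)

# `mask` accepts either a boolean pytree or a callable producing one.
optim = optax.adamw(learning_rate=3e-4, weight_decay=0.1, mask=decay_mask)

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}  # toy parameters
opt_state = optim.init(params)
```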
Updated the repo. Re unroll: I did it, and even though I got a nice bump in the number of tokens processed per second, the time taken per step got much worse. I would rather have paid a one-time compilation penalty. Also, the performance is still nowhere close to torch, neither in terms of tokens (163K tokens processed in torch vs 22K tokens processed in JAX) nor in terms of time taken. |
Thanks for trying and for the feedback!
Can you explain what that means? Presumably, if we're processing more tokens per second, then the time has to be better---or am I missing something? If you're talking about the first step, then that is because the first step includes compilation overhead. (The first call to the function triggers compilation.)
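To illustrate the compilation-overhead point: a common way to benchmark a jitted function is to call it once to trigger compilation, then time subsequent calls with block_until_ready so that asynchronous dispatch doesn't skew the numbers. A minimal sketch (the function and shapes here are placeholders, not from the repo):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((1024, 1024))
step(x).block_until_ready()  # first call: includes compilation time

t0 = time.time()
step(x).block_until_ready()  # later calls: runtime only
print(f"steady-state step time: {(time.time() - t0) * 1e3:.2f} ms")
```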
Can you give some details about the hardware you are running on? You said you are running on an A100-40, but, as posted above, I am getting ~80K tokens processed per second in JAX on my side on an A100-40 (SXM) node.

Separately, I am facing the following error when running the PyTorch code. Do you have any idea where that might be coming from? Is there perhaps a missing requirement? |
Please take a look at the numbers below: [screenshot omitted]
Sure thing. Here is the output: [screenshot omitted]
No, nothing is missing from the requirements. One question for you: how are you unrolling the scan? Maybe we are getting different numbers because of that? |
Thank you for sharing! The chip is the same, so that's not the difference. One difference worth noting (though not sure if it matters) is that I'm using an older CUDA version than you are:
Thanks, it seems like retrying with a new env worked after a couple of tries! (There were some issues due to incomplete installs of some packages.) This difference is still very much unexpected of course, so I'll keep digging :)
I just added the unroll argument to the scan over layers:

(x, layer_idx), _ = jax.lax.scan(f, (x, layer_idx), dynamic_layers, unroll=1000)  # unrolling to 12 should suffice, I think

One thing I also noticed is that you're not actually calling into a fused cuDNN attention kernel. The call

attn = jax.nn.dot_product_attention(q, k, v, is_causal=True)

could be rewritten into

attn = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")

for that. There is a note about this in the jax.nn.dot_product_attention documentation. |
Good to know. Yeah, sometimes it happens if there is a silent package conflict or a corrupted install from the cache
Yes, I was shocked to see the performance difference between the two.
I did the same but with a much lower number, and you are right, unrolling to 12 should be enough. Will check that.
I read it when I opened a discussion on the attention functionality the month before last. Still missed it. Sigh! |
So, after updating my system to (roughly) match your driver version, I was able to run with both setups. Using the changes discussed above, here are the new numbers: [output omitted] |
We can glean another few tokens/s by rewriting the code to use a different path for gradient accumulation vs. the actual model update:

+@eqx.filter_jit(donate='all-except-first')
+def train_step_accum(model, optim, optim_state, data, targets):
+ _, grads = compute_loss(model, data, targets)
+ _, optim_state = optim.update(
+ grads, optim_state, eqx.filter(model, eqx.is_array))
+ return optim_state
+
def main(data_file_path):
total_batch_size = 524288 # 2**19, ~0.5M, in number of tokens
B = 8 # micro batch size
@@ -433,13 +440,22 @@ def main(data_file_path):
t0 = time.time()
for micro_step in range(grad_accum_steps):
batch_inputs, batch_targets = train_loader.next_batch()
- loss, model, optim_state = train_step(
+ if micro_step < grad_accum_steps - 1:
+ optim_state = train_step_accum(
+ model,
+ optim,
+ optim_state,
+ batch_inputs,
+ batch_targets,
+ )
+ else:
+ loss, model, optim_state = train_step(
model,
optim,
optim_state,
batch_inputs,
batch_targets,
- )
+ )

This avoids making no-op updates to the model and rewriting all the weights to HBM at each accumulation step. This gets us to ~140k tokens/second:
|
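As an aside on the donate='all-except-first' decorator used above: in plain JAX the same buffer-donation idea is expressed with jax.jit(donate_argnums=...), which lets XLA reuse the donated input buffers for the outputs instead of allocating fresh ones on every accumulation step. A rough sketch (the accumulate function is a stand-in, not the repo's code):

```python
from functools import partial
import jax
import jax.numpy as jnp

# Donate the accumulator: XLA may reuse its buffers for the output, avoiding a
# fresh HBM allocation/copy on every gradient-accumulation step.
@partial(jax.jit, donate_argnums=(0,))
def accumulate(grad_accum, grads):
    return jax.tree_util.tree_map(lambda a, g: a + g, grad_accum, grads)

accum = {"w": jnp.zeros((4, 4))}
grads = {"w": jnp.ones((4, 4))}
accum = accumulate(accum, grads)  # rebind the name; donated input buffers are invalid after the call
```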
Thanks. I am able to reproduce this number on my side. One question though: right now, we are using unroll=1000 in the scan; is that value intentional?

This is a nitty-gritty detail. We should add documentation for this in optax. |
Awesome that you could reproduce it! I hacked around, but the goal is to unroll all the layers here---setting unroll to any value at least as large as the number of layers fully unrolls the scan.
Agreed! Where do you think that would belong exactly? (This is my first time using it.) Another (marginal) improvement is to pre-process the data in the data loader to avoid doing a copy at each loop step.
|
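A minimal sketch of that data-loader idea, assuming the loader currently converts a host array to a device array on every next_batch() call: move the whole token buffer to the device once and slice per step. Class and field names are illustrative, not the repo's.

```python
import jax
import numpy as np

class PreloadedDataLoader:
    def __init__(self, tokens: np.ndarray, B: int, T: int):
        # Single host-to-device transfer up front, instead of one per batch.
        self.tokens = jax.device_put(tokens)
        self.B, self.T, self.pos = B, T, 0

    def next_batch(self):
        n = self.B * self.T
        if self.pos + n + 1 > self.tokens.shape[0]:
            self.pos = 0
        buf = self.tokens[self.pos : self.pos + n + 1]
        x = buf[:-1].reshape(self.B, self.T)  # inputs
        y = buf[1:].reshape(self.B, self.T)   # next-token targets
        self.pos += n
        return x, y
```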
Yeah, when I used a lower unroll value, I saw the same thing.
This should go right here. Let me know if you need more info or any help on this.
Yeah we can do that. I just left it raw for now |
This commit should fix the issue at HEAD!
Actually, if you have a good idea of what useful documentation for you would have looked like, I'd love to see it! (Maybe you even want to send a PR?! :)) |
Cool. Leave that to me. I will make that PR both in optax and Equinox |
I didn't have a lot of bandwidth to go further this week, but I recalled that there are a few improvements in XLA that are not yet in the last release---so I wanted to share at least that. Here are some numbers with the latest nightlies:
So we're at roughly ~150k tok/sec for now. |
No worries @bchetioui. We are already close to the optimal number. I am hoping this exercise turns out to be extremely valuable for both of us in terms of performance benchmarks. |
@bchetioui were you able to find any more bottlenecks? |
Hi @AakashKumarNain; I have a few guesses of where to look, but haven't looked into the profiles deeply just yet. At this point, it seems like we need to fix XLA performance. With that said, we can still improve the model at the JAX level directly. We know that XLA currently lowers the layer norms to several kernels rather than one, so the hypothesis is that plugging in a Pallas layer norm kernel should help:

+from jax.experimental.pallas.ops.gpu import layer_norm as pallas_layer_norm
######################## Equinox model utils ################################
+def layer_norm_wrapper(self, x, *args, **kwargs):
+ x = jnp.reshape(x, (1, *x.shape))
+ res, *_ = pallas_layer_norm.layer_norm(x, self.weight, self.bias)
+ return jnp.reshape(res, x.shape[1:])
+
+eqx.nn.LayerNorm.__call__ = layer_norm_wrapper
+
+
def is_layer(x):
"""Check if the current pytree is an instance of any Equinox layer."""
return isinstance(x, (eqx.nn.Linear, eqx.nn.Embedding, eqx.nn.LayerNorm))
@@ -214,9 +223,9 @@ class TransformerBlock(eqx.Module):
self.mlp = MLP(config, key=key2, dtype=dtype)
def __call__(self, x, mask=None):
- x = eqx.filter_vmap(self.norm_1)(x)
+ x = self.norm_1(x)
x = x + self.attn(x, mask=mask)
- x = eqx.filter_vmap(self.norm_2)(x)
+ x = self.norm_2(x)
x = x + self.mlp(x)
return x
@@ -298,7 +307,7 @@ class GPT(eqx.Module):
unroll=self.n_layer + 1)
# 4. Final pre-layer norm
- x = eqx.filter_vmap(self.norm)(x).astype(jnp.bfloat16)
+ x = self.norm(x).astype(jnp.bfloat16)
# 5. Classification head
     logits = eqx.filter_vmap(lm_head)(x)

(Note, the plugging in is pretty hacky; the Pallas op may deserve some love :-)) The hypothesis is validated, since plugging that in gets us to roughly 156k tokens/second:
There were previous reports of Adam slowness; I'd like to poke at this and see whether there is any low-hanging fruit there, but I have yet to find cycles to do so. |
Can you point me to the issue/discussion/documentation where this was found and discussed? I would like to know the reason behind it, because it may affect other ops as well.
No worries. Thanks for all the help and support you have provided. This is already looking good. In an ideal world, JAX should be much faster than torch. That is true on TPUs, but not on GPUs. The problem is that GPUs are far more widespread than TPUs, and if JAX is slow on GPU, then no matter how hard people (like you and me) try to convince others that JAX is better, it will not be enough. |
Unfortunately, I'm not aware of any public-facing bug---only internal ones. What it boils down to, however, is that normalizations (any of them, including e.g. Softmax) end up lowering to several kernels by default, as opposed to a single one. We (and I personally) have done significant work towards fixing this, but it's not yet on by default, and doesn't actually match everything that it should.
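One way to see this for yourself (a sketch, not from the thread): lower and compile a small layer-norm-like function ahead of time, then inspect the optimized HLO text to get a rough count of how many fusions (kernels) it produced.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) * jax.lax.rsqrt(var + eps)

x = jnp.ones((8, 1024, 768), dtype=jnp.bfloat16)
compiled = jax.jit(layer_norm).lower(x).compile()
hlo = compiled.as_text()
# Rough heuristic: count the fusion instructions emitted in the optimized HLO.
print(hlo.count(" fusion("), "fusion ops in the optimized HLO")
```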
Happy to help! We're always working on making our stack better, and reports such as yours are very helpful to figure out what the sharp bits are for our external users, and to find out things we can do better. So thank you for filing this! |
Oh okay. Any approximate timeline for the fixes that are already done but not in the dev branch?
No worries. Always happy to help make things better 🍻 |
The changes that are done are all in the dev branch! And in fact, you can already trigger the experimental rewriter using an XLA flag. |
Oh I have tried those, and it didn't work that time. Here is the corresponding issue: openxla/xla#17103 |
@bchetioui did you find anything else that can speed up the runs? |
Hi @AakashKumarNain; I didn't yet investigate further for this particular model. In general, we have a lot of work in the pipeline that should have a generally good effect on such models, but I haven't checked how they'll apply specifically to this benchmark. |
No worries. I can do that. Please let me know once these improvements are in the nightly or a stable version. I can restart the benchmarking |
I will keep you updated. The next big thing will hopefully be a fix for loop unrolling, which should allow us to reclaim compile-time and get the same performance as with unrolling |
@bchetioui Is there a specific place where that's being discussed or tracked? I created the following related posts: |
Hi @carlosgmartin, I do not think we have public bugs for this. Regarding openxla/stablehlo#2664, it may in fact not be necessary to add a new op to avoid the round trip. If the whole loop body can fit in a single kernel, and the number of iterations is known (which we can already annotate in the compiler), we could simply lower to a loop on device. Do you have a concrete (real) motivating example where the round trip to host is problematic? |
That sounds like what I had in mind. Keeping all computation on the device without having to return control to the host on each iteration, while avoiding the long compilation time from naive unrolling. Could you clarify what you mean exactly by "roundtrip to host"? |
Returning control to the host on each iteration :) |
My understanding is that returning control to the host on each iteration (GPU ↔ CPU) is responsible for most of the slowdown that occurs at runtime when changing from a fully unrolled scan back to unroll=1.
In my own work, |
There is some overhead for sure, but I haven't quantified how much. The bulk of the overhead today comes from writing outputs to intermediate buffers before inserting them into the loop outputs, which is what the rolled loop does today. When you unroll, that issue doesn't occur :) |
Hi @bchetioui Are the changes merged (including the experimental pallas kernel for norm layers)? If so, I can start the benchmarking again |
Hi @AakashKumarNain, The PR to enable the fix for loop unrolling is here, but it is currently gated on fixing some broken tests. Hopefully it can be made available soon!
I think you're asking about the XLA lowering for norm layers, which already landed a while back---but which, as I recall, may need a little more massaging to apply to your specific workload. (In any case, this should already be reflected in the performance of a default run.) |
Thanks @bchetioui I will wait a little more before doing this again then |
Description
Last week a discussion took place on Twitter where many people noticed that the performance of JAX on GPUs is still subpar compared to PyTorch. @mattjj and I discussed it afterwards and agreed that a repo with a minimal codebase showcasing the performance differences between the two would be great.
GPT-2 is a good model to start with, and here is a repo that contains code for the same model in both JAX and PyTorch. The instructions provided in the repo are enough to download the data and run the code locally (on a GPU machine).
On my side, these are the results I got on an A100 40G machine:
JAX: [results screenshot omitted]
PyTorch: [results screenshot omitted]
Compared to Torch, JAX is extremely slow here, and I have no idea why. There may be a silent bug somewhere in the JAX code that I have overlooked. Given the number of times I have been through it, I think a fresh set of eyes would do it better justice. Please let me know if you need any other information from my side.
System info (python version, jaxlib version, accelerator, etc.)
Optax is installed from git because there was a fix for adamw that was not part of the last release.
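For reference, installing optax from git typically looks like the following (this assumes the upstream repository location; pin a specific commit if you need the exact fix):

```
pip install -U "git+https://github.com/google-deepmind/optax.git"
```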