
Do not materialize full text/image/speech logits (10GB for 16k seq text). #1236


Open
ds-hwang wants to merge 1 commit into main

Conversation

@ds-hwang (Contributor) commented Jun 3, 2025

When the vocab size is large and the context length is long, full logits consume a significant amount of memory.
For example, with a sequence length of 16k and a vocab size of 153k, the full logits for a single batch in FP32 take about 9.6 GB of HBM.
This is already happening in the 3B model, and eliminating this memory usage could reduce peak forward memory by roughly 1/3.
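
As a quick back-of-the-envelope check of those numbers (assuming "16k" and "153k" mean 16×1024 tokens and 153,000 vocab entries; the exact figure depends on the true sizes):

```python
# Full FP32 logits for one sequence: [seq_len, vocab_size] * 4 bytes per element.
seq_len, vocab_size = 16 * 1024, 153_000  # assumed exact values behind "16k" / "153k"
print(f"{seq_len * vocab_size * 4 / 1e9:.1f} GB")  # ~10.0 GB, i.e. the ~10 GB in the title
```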

Logits are mainly used to compute ids and cross entropy, and after those ops we only need tensors of shape [16k, 1]. So we can achieve the same outcome without ever materializing the full logits.

This PR computes ids and cross entropy inside a jax.lax.scan, processing scan_chunk tokens per iteration. With scan_chunk=512, peak logits memory drops from 16k × 153k (9.6 GB) to 512 × 153k (313 MB), a substantial reduction.

Furthermore, since logits are the primary memory bottleneck for long sequences, this approach allows us to support longer contexts (e.g., 32k, 128k, 1M) while keeping logits memory fixed at scan_chunk tokens.
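
A minimal sketch of the idea (illustrative only; the function and argument names below are made up and this is not the PR's actual CausalLM / scan_in_context code): slice the time axis into scan_chunk-sized chunks and compute logits, cross entropy, and argmax ids per chunk inside jax.lax.scan, with the chunk body wrapped in jax.checkpoint (remat) so only one chunk's logits are live at a time.

```python
import jax
import jax.numpy as jnp


def chunked_loss_and_ids(hidden, lm_head, targets, *, scan_chunk=512):
    """Per-token cross entropy and argmax ids without materializing full logits.

    hidden:  [seq_len, hidden_dim] final-layer activations.
    lm_head: [hidden_dim, vocab_size] output projection.
    targets: [seq_len] int32 target token ids.
    """
    seq_len = hidden.shape[0]
    assert seq_len % scan_chunk == 0, "sketch assumes seq_len divides evenly into chunks"
    num_chunks = seq_len // scan_chunk

    # Fold the time axis into [num_chunks, scan_chunk, ...] so scan walks over chunks.
    hidden_chunks = hidden.reshape(num_chunks, scan_chunk, -1)
    target_chunks = targets.reshape(num_chunks, scan_chunk)

    @jax.checkpoint  # remat: recompute this chunk's logits in the backward pass
    def chunk_fn(carry, xs):
        h, t = xs
        logits = h @ lm_head  # [scan_chunk, vocab]: only one chunk is ever live
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        ce = -jnp.take_along_axis(log_probs, t[:, None], axis=-1)[:, 0]  # [scan_chunk]
        ids = jnp.argmax(logits, axis=-1)  # [scan_chunk]
        return carry, (ce, ids)

    _, (ce, ids) = jax.lax.scan(chunk_fn, init=None, xs=(hidden_chunks, target_chunks))
    return ce.reshape(seq_len), ids.reshape(seq_len)
```

With scan_chunk=512 and a ~153k vocab, the float32 logits inside one iteration are roughly 313 MB, matching the peak figure above; the jax.checkpoint wrapper keeps the backward pass from saving every chunk's logits as residuals.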

Change details:

  • CausalLM._metrics() uses a scan to compute logits in chunks, and the LossMetrics classes compute loss metrics per chunk. After the scan, the metrics from all num_iterations are aggregated.
  • To support this, scan_in_context gains an option that merges summaries across scan iterations before returning (see the generic merge sketch after these bullets).
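
The aggregation in the second bullet might look roughly like the following generic sketch (not the PR's scan_in_context / LossMetrics code; WeightedScalarLike is a made-up stand-in for whatever weighted summary type the codebase uses): each chunk reports a value plus a weight (e.g. its live-target count), and merging across iterations is a weighted mean.

```python
from typing import NamedTuple

import jax.numpy as jnp


class WeightedScalarLike(NamedTuple):
    """Stand-in for a weighted summary: a mean over `weight` items."""
    mean: jnp.ndarray
    weight: jnp.ndarray


def merge_across_iterations(per_chunk: WeightedScalarLike) -> WeightedScalarLike:
    """Merge per-chunk summaries stacked along a leading num_iterations axis."""
    total_weight = per_chunk.weight.sum()
    # Weighted mean, guarding against an all-padding (zero-weight) batch.
    merged = (per_chunk.mean * per_chunk.weight).sum() / jnp.maximum(total_weight, 1)
    return WeightedScalarLike(mean=merged, weight=total_weight)
```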

Why jax.lax.scan?
Among all tested options, jax.lax.scan with a jax.remat-wrapped body offered the best trade-off between memory and speed. The other alternatives either saved little memory or were significantly slower. fori_loop came close in behavior but was slightly slower and does not support autodiff.

| Method                 | HBM (GB) | Time (s) |
|------------------------|----------|----------|
| as-is                  | 18.1     | 0.56     |
| scan                   | 24.0     | 0.90     |
| **scan (remat)**       | 14.4     | 0.62     |
| scan (remat + jit)     | 14.4     | 0.61     |
| fori                   | 32.7     | 2.99     |
| fori (remat)           | 16.4     | 0.62     |
| fori (remat + jit)     | 16.4     | 0.63     |
| lax.map                | 42.0     | 2.93     |
| jax.vmap               | 24.0     | 0.57     |
| jax.vmap (remat)       | 24.0     | 0.57     |
| jax.vmap (remat + jit) | 24.0     | 0.45     |
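
For reference, the fori_loop variant from the table would look roughly like the following (a sketch using the same made-up shapes as the scan sketch above; as noted, it does not support autodiff, so it could only serve the forward/metrics path):

```python
import jax
import jax.numpy as jnp


def chunked_ce_fori(hidden_chunks, lm_head, target_chunks):
    """fori_loop counterpart of the scan sketch above (forward pass only).

    hidden_chunks: [num_chunks, scan_chunk, hidden_dim]
    target_chunks: [num_chunks, scan_chunk]
    """
    num_chunks, scan_chunk, _ = hidden_chunks.shape

    def body(i, acc):
        logits = hidden_chunks[i] @ lm_head  # one chunk's logits at a time
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        ce = -jnp.take_along_axis(log_probs, target_chunks[i][:, None], axis=-1)[:, 0]
        return acc.at[i].set(ce)

    acc0 = jnp.zeros((num_chunks, scan_chunk), hidden_chunks.dtype)
    return jax.lax.fori_loop(0, num_chunks, body, acc0).reshape(-1)
```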

A new scan_chunk config is introduced.
The default is None, which disables scanning.
To enable scanning, set a value; 512 is recommended for both TPU and GPU. For example, with a sequence length of 16k and a vocab size of 150k, this saves around 10 GB of memory at the cost of about a 5% slowdown, per the benchmark results below.
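
In code form, the None-vs-set behavior is roughly the following (a sketch reusing chunked_loss_and_ids from the earlier sketch; in the PR the knob lives in the model config rather than a function argument):

```python
import jax
import jax.numpy as jnp


def loss_and_ids(hidden, lm_head, targets, scan_chunk=None):
    """Sketch of the new option's semantics; `chunked_loss_and_ids` is defined above."""
    if scan_chunk is None:
        # Default (None): the existing path that materializes full [seq_len, vocab] logits.
        logits = hidden @ lm_head
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        ce = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
        return ce, jnp.argmax(logits, axis=-1)
    # Any integer enables the chunked scan; 512 is the recommended setting.
    return chunked_loss_and_ids(hidden, lm_head, targets, scan_chunk=scan_chunk)
```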

  • v7/v8 3B Pretrain (seq_len=8k, FSDP=4, TPUv5p)
    • scan_chunk 512 saves 2–3 GB of HBM with almost no impact on speed.

| Configuration      | Train HBM | Train sec/step | Forward HBM | Forward sec/step |
|--------------------|-----------|----------------|-------------|------------------|
| As-is              | 19.6 GB   | 1.116 s        | 8.43 GB     | 296 ms           |
| scan_chunk 64      | 17.4 GB   | 1.172 s        | 5.59 GB     | 295 ms           |
| scan_chunk 128     | 17.4 GB   | 1.172 s        | 5.59 GB     | 291 ms           |
| scan_chunk 256     | 17.4 GB   | 1.160 s        | 5.59 GB     | 290 ms           |
| **scan_chunk 512** | 17.4 GB   | 1.161 s        | 5.59 GB     | 290 ms           |
| **scan_chunk 512** | -11%      | +4%            | -44%        | -2%              |
| scan_chunk 1k      | 17.6 GB   | 1.161 s        | 5.59 GB     | 291 ms           |
| scan_chunk 2k      | 18.6 GB   | 1.158 s        | 5.59 GB     | 289 ms           |
  • v7 3B + audio tokenizer Post-train (seq_len=16k, FSDP=4, TPUv5p)
    • scan_chunk 512 saves 10 GB of HBM but results in a 15–28% slowdown (likely due to audio embedding sharding; needs further investigation).

| Configuration      | Train HBM | Train sec/step | Forward HBM | Forward sec/step |
|--------------------|-----------|----------------|-------------|------------------|
| As-is              | 38.8 GB   | 3.070 s        | 19.8 GB     | 0.899 s          |
| scan_chunk 64      | 30.7 GB   | 3.530 s        | 9.22 GB     | 1.158 s          |
| scan_chunk 128     | 27.6 GB   | 3.560 s        | 9.22 GB     | 1.096 s          |
| scan_chunk 256     | 27.6 GB   | 3.510 s        | 9.22 GB     | 1.149 s          |
| **scan_chunk 512** | 27.6 GB   | 3.530 s        | 9.22 GB     | 1.149 s          |
| **scan_chunk 512** | -29%      | +15%           | -53%        | +28%             |
| scan_chunk 1k      | 28.0 GB   | 3.530 s        | 9.23 GB     | 1.152 s          |
| scan_chunk 2k      | 30.7 GB   | 3.530 s        | 12.1 GB     | 1.148 s          |
  • v7/v8 3B Pretrain (seq_len=8k, FSDP=8, H100)
    • scan_chunk 512 saves 3.7 GB of HBM with a 7% slowdown.

| Configuration      | Train HBM | Train sec/step |
|--------------------|-----------|----------------|
| As-is              | 18.1 GB   | 1.670 s        |
| scan_chunk 128     | 14.4 GB   | 1.850 s        |
| scan_chunk 256     | 14.4 GB   | 1.806 s        |
| **scan_chunk 512** | 14.5 GB   | 1.786 s        |
| **scan_chunk 512** | -20%      | +7%            |
| scan_chunk 1k      | 14.7 GB   | 1.769 s        |
| scan_chunk 2k      | 15.3 GB   | 1.764 s        |
| scan_chunk 8k      | 24.1 GB   | 1.729 s        |

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners June 3, 2025 21:08
@ds-hwang (Contributor, Author) commented Jun 3, 2025

@ruomingp Could you review it? From 1198
