Do not materialize full text/image/speech logits (10GB for 16k seq text). #1236
Open
ds-hwang wants to merge 1 commit into main from scan_logit
Conversation
@ruomingp Could you review it? From #1198.
When the vocab size is large and the context length is long, full logits consume a significant amount of memory. For example, with a sequence length of 16k and a vocab size of 153k, a single batch in FP32 consumes about 9.6 GB of HBM. This is currently happening in the 3B model, and eliminating this memory usage reduces peak forward memory by roughly 1/3.

Logits are mainly used to compute ids and cross entropy, and after these ops we only need memory of shape [16k, 1], so we can achieve the same outcome without materializing the full logits.

This PR computes ids and cross entropy inside a jax.lax.scan that processes `scan_chunk` tokens per step. When `scan_chunk=512`, peak logits memory drops from 16k × 153k (9.6 GB) to 512 × 153k (313 MB), a substantial reduction. Furthermore, since logits are the primary memory bottleneck for long sequences, this approach allows us to support longer contexts (e.g., 32k, 128k, 1M) while keeping logits memory fixed at `scan_chunk` tokens.

**Change details:**

* CausalLM._metrics() uses scan to compute logits in chunks, and the LossMetrics classes compute loss metrics per chunk. After the scan, the metrics across `num_iterations` are aggregated.
* To support this change, scan_in_context was updated with an option that merges summaries across scan iterations before returning.

**Why jax.lax.scan?**

Among all tested options, jax.lax.scan over a jax.remat-wrapped body offered the best trade-off between memory and speed; a minimal sketch follows the table below. The other alternatives either didn't save much memory or were significantly slower. fori_loop was close in behavior but had slightly worse performance and does not support autodiff.

| Method                 | HBM (GB) | Time (s) |
|------------------------|----------|----------|
| as-is                  | 18.1     | 0.56     |
| scan                   | 24.0     | 0.90     |
| **scan (remat)**       | 14.4     | 0.62     |
| scan (remat + jit)     | 14.4     | 0.61     |
| fori                   | 32.7     | 2.99     |
| fori (remat)           | 16.4     | 0.62     |
| fori (remat + jit)     | 16.4     | 0.63     |
| lax.map                | 42.0     | 2.93     |
| jax.vmap               | 24.0     | 0.57     |
| jax.vmap (remat)       | 24.0     | 0.57     |
| jax.vmap (remat + jit) | 24.0     | 0.45     |
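For intuition, here is a minimal, self-contained sketch of the scan (remat) approach. This is not the actual AXLearn implementation; the function name, the flat `[seq_len, dim]` layout, and the divisibility assumption are illustrative. Each scan step materializes only `[scan_chunk, vocab]` logits, and `jax.checkpoint` (remat) recomputes them on the backward pass instead of saving them.

```python
import jax
import jax.numpy as jnp


def chunked_xent(hidden, targets, emb, *, scan_chunk=512):
    """Per-token cross entropy and predicted ids without materializing
    the full [seq_len, vocab] logits.

    hidden:  [seq_len, dim] final-layer activations.
    targets: [seq_len] int32 target ids.
    emb:     [dim, vocab] LM-head projection.
    """
    seq_len, dim = hidden.shape
    assert seq_len % scan_chunk == 0, "sketch assumes seq_len % scan_chunk == 0"
    h_chunks = hidden.reshape(-1, scan_chunk, dim)  # [n_chunks, scan_chunk, dim]
    t_chunks = targets.reshape(-1, scan_chunk)      # [n_chunks, scan_chunk]

    @jax.checkpoint  # remat: recompute chunk logits in backward instead of saving them.
    def step(total_nll, xs):
        h, t = xs
        logits = h @ emb  # only [scan_chunk, vocab] is live at any one time.
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        nll = -jnp.take_along_axis(log_probs, t[:, None], axis=-1)[:, 0]
        ids = jnp.argmax(logits, axis=-1)  # predicted ids for this chunk.
        return total_nll + nll.sum(), ids

    total_nll, ids = jax.lax.scan(step, jnp.zeros(()), (h_chunks, t_chunks))
    return total_nll / seq_len, ids.reshape(seq_len)
```

With `seq_len=16k`, `vocab=153k`, and `scan_chunk=512`, each step's FP32 logits occupy 512 × 153k × 4 B ≈ 313 MB instead of ~9.6 GB for the full array, matching the numbers above.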
**New `scan_chunk` config is introduced.** The default is None, which disables scanning; to enable it, set a value. 512 is recommended for both TPU and GPU. For example, with a sequence length of 16k and a vocab size of 150k, this saves around 10 GB of memory at the cost of about a 5% slowdown, according to the following benchmark results. In each table, the second **scan_chunk 512** row gives the deltas relative to as-is.

* v7/v8 3B Pretrain (seq_len=8k, FSDP=4, TPUv5p): `scan_chunk 512` saves 2–3 GB of HBM with almost no impact on speed.

| Configuration          | Train HBM | Train sec/step | Forward HBM | Forward sec/step |
|------------------------|-----------|----------------|-------------|------------------|
| As-is                  | 19.6 GB   | 1.116 s        | 8.43 GB     | 296 ms           |
| scan_chunk 64          | 17.4 GB   | 1.172 s        | 5.59 GB     | 295 ms           |
| scan_chunk 128         | 17.4 GB   | 1.172 s        | 5.59 GB     | 291 ms           |
| scan_chunk 256         | 17.4 GB   | 1.160 s        | 5.59 GB     | 290 ms           |
| **scan_chunk 512**     | 17.4 GB   | 1.161 s        | 5.59 GB     | 290 ms           |
| **scan_chunk 512** (Δ) | -11%      | +4%            | -44%        | -2%              |
| scan_chunk 1k          | 17.6 GB   | 1.161 s        | 5.59 GB     | 291 ms           |
| scan_chunk 2k          | 18.6 GB   | 1.158 s        | 5.59 GB     | 289 ms           |

* v7 3B + audio tokenizer Post-train (seq_len=16k, FSDP=4, TPUv5p): `scan_chunk 512` saves 10 GB of HBM but results in a 15–28% slowdown (likely due to audio embedding sharding; needs investigation).

| Configuration          | Train HBM | Train sec/step | Forward HBM | Forward sec/step |
|------------------------|-----------|----------------|-------------|------------------|
| As-is                  | 38.8 GB   | 3.070 s        | 19.8 GB     | 0.899 s          |
| scan_chunk 64          | 30.7 GB   | 3.530 s        | 9.22 GB     | 1.158 s          |
| scan_chunk 128         | 27.6 GB   | 3.560 s        | 9.22 GB     | 1.096 s          |
| scan_chunk 256         | 27.6 GB   | 3.510 s        | 9.22 GB     | 1.149 s          |
| **scan_chunk 512**     | 27.6 GB   | 3.530 s        | 9.22 GB     | 1.149 s          |
| **scan_chunk 512** (Δ) | -29%      | +15%           | -53%        | +28%             |
| scan_chunk 1k          | 28.0 GB   | 3.530 s        | 9.23 GB     | 1.152 s          |
| scan_chunk 2k          | 30.7 GB   | 3.530 s        | 12.1 GB     | 1.148 s          |

* v7/v8 3B Pretrain (seq_len=8k, FSDP=8, H100): `scan_chunk 512` saves 3.7 GB of HBM with a 7% slowdown.

| Configuration          | Train HBM | Train sec/step |
|------------------------|-----------|----------------|
| As-is                  | 18.1 GB   | 1.670 s        |
| scan_chunk 128         | 14.4 GB   | 1.850 s        |
| scan_chunk 256         | 14.4 GB   | 1.806 s        |
| **scan_chunk 512**     | 14.5 GB   | 1.786 s        |
| **scan_chunk 512** (Δ) | -20%      | +7%            |
| scan_chunk 1k          | 14.7 GB   | 1.769 s        |
| scan_chunk 2k          | 15.3 GB   | 1.764 s        |
| scan_chunk 8k          | 24.1 GB   | 1.729 s        |
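For completeness, a hypothetical usage sketch: `scan_chunk` is the field this PR introduces, but the exact config path shown here (directly on the CausalLM model config) is an assumption, not verified API.

```python
from axlearn.common import causal_lm

# Hypothetical: enable chunked logit computation; None (the default)
# keeps the as-is full-logits path.
model_cfg = causal_lm.Model.default_config()
model_cfg.scan_chunk = 512
```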