feat(inference): fp32 lm_head for stable logprobs under FP8/bf16 #2438
Open
samsja wants to merge 1 commit into
Conversation
Force-pushed from 2984e45 to 3efcdd8
Adds `inference.enable_fp32_lm_head` to run the lm_head projection in fp32 instead of the model dtype. Closes the rollout-vs-trainer logprob gap that shows up in low-precision RL setups, where the head matmul is the only remaining bf16 step before the sampler (which already promotes to fp32). Mirrors SGLang's `--enable-fp32-lm-head` and tracks vllm-project/vllm#24567, restricted here to the untied-embedding case.

Wired through vLLM's `additional_config` dict so workers read the flag from `get_current_vllm_config()` — no env var. Lazily caches `lm_head.weight.float()` on the layer so the cast happens once.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Jackmin801 reviewed on May 8, 2026
Comment on lines +1131 to +1166
```python
def _patched_init(self, *args, **kwargs):
    _original_init(self, *args, **kwargs)
    vllm_config = get_current_vllm_config()
    # `or {}` defends against downstream code overriding additional_config to None
    # (default is an empty dict, but defensive copying is cheap).
    additional_config = vllm_config.additional_config or {}
    self._fp32_lm_head_enabled = additional_config.get("fp32_lm_head", False)
    if self._fp32_lm_head_enabled:
        if getattr(vllm_config.model_config.hf_config, "tie_word_embeddings", False):
            raise NotImplementedError(
                "fp32_lm_head is not supported for models with tie_word_embeddings=True; "
                "casting lm_head.weight in place would also cast embed_tokens. Use a model "
                "with untied embeddings, or extend this patch to keep a separate fp32 copy."
            )
        logger.warning("fp32 lm_head ENABLED for this LogitsProcessor instance.")


def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
    if not getattr(self, "_fp32_lm_head_enabled", False):
        return _original_get_logits(self, hidden_states, lm_head, embedding_bias)

    # Cache the fp32 weight, but invalidate it whenever lm_head.weight has
    # been mutated. RL weight broadcasts (NCCL / filesystem) update params
    # in place via param.copy_(...), which bumps Tensor._version. Detect
    # that and rebuild — otherwise we'd serve stale step-0 weights forever.
    weight = lm_head.weight
    cur_version = weight._version
    cached = getattr(lm_head, "_fp32_lm_head_weight", None)
    cached_version = getattr(lm_head, "_fp32_lm_head_weight_version", None)
    if cached is None or cached_version != cur_version:
        with torch.no_grad():
            cached = weight.detach().to(torch.float32)
        lm_head._fp32_lm_head_weight = cached
        lm_head._fp32_lm_head_weight_version = cur_version

    bias_fp32 = embedding_bias.to(torch.float32) if embedding_bias is not None else None
    logits = F.linear(hidden_states.to(torch.float32), cached, bias_fp32)
```
Jackmin801 (Member): I think the only thing we gain here is removing the epilogue bf16 downcast. We should be able to replace this with `torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32)` and do this more efficiently.
samsja added a commit that referenced this pull request on May 8, 2026
Alternative implementation of fp32 lm_head proposed by @Jackmin801 on PR #2438: use `torch.mm(..., out_dtype=torch.float32)` (PyTorch >= 2.10) so the matmul accumulates and emits fp32 directly, without zero-padding the bf16 operands or maintaining a separate fp32 weight copy.

The win is purely in the epilogue: `F.linear(bf16, bf16)` accumulates in fp32 internally on tensor cores but truncates the output to bf16 before returning. Native out_dtype=fp32 keeps the full mantissa for the downstream softmax/sample.

Removes the cached fp32 weight + tensor-version invalidation logic from PR #2438, which is not needed when there's no separate fp32 copy.

Wired identically to PR #2438: `inference.enable_fp32_lm_head` config flag -> `additional_config["fp32_lm_head"]` -> captured on `LogitsProcessor.__init__` -> per-instance flag read in `_get_logits`.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
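A minimal sketch of what this variant's `_get_logits` patch could look like, reusing the names from the review excerpt above. The bias handling and the `return` are assumptions (the excerpt above ends before them), and `out_dtype=` on `torch.mm` requires a PyTorch build that supports it (>= 2.10 per the commit message):

```python
import torch

def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
    if not getattr(self, "_fp32_lm_head_enabled", False):
        return _original_get_logits(self, hidden_states, lm_head, embedding_bias)
    # bf16 @ bf16 already accumulates in fp32 on tensor cores; out_dtype only
    # skips the bf16 epilogue downcast, so no separate fp32 weight copy is kept.
    logits = torch.mm(hidden_states, lm_head.weight.t(), out_dtype=torch.float32)
    if embedding_bias is not None:
        logits = logits + embedding_bias.to(torch.float32)  # assumed bias path
    return logits
```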

Summary
- `inference.enable_fp32_lm_head` flag that runs the lm_head projection in fp32 instead of the model dtype (bf16). Both `hidden_states` and `lm_head.weight` are cast to float32 before the matmul.
- Wired through vLLM's `additional_config` dict (no env vars). The launcher's `to_vllm()` injects `{"fp32_lm_head": True}` when the flag is set.
- The flag is captured once at init time, in `LogitsProcessor.__init__`, onto the instance; the per-forward check then just reads `self._fp32_lm_head_enabled`. This matters because `get_current_vllm_config()` only resolves during model init — calling it from `_get_logits` during serving raises `AssertionError` (we verified empirically that the per-call approach silently no-ops).
- The fp32 weight cache is keyed on `Tensor._version` so RL weight broadcasts (NCCL / filesystem `param.copy_(...)`) don't leave a stale step-0 fp32 copy serving forever (see the toy demo below).
- Raises `NotImplementedError` for `tie_word_embeddings=True`, since casting `lm_head.weight` in place would also cast `embed_tokens`.
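The toy referenced above (standalone, not PR code) shows the invalidation signal the cache relies on: in-place updates such as `param.copy_(...)` bump `Tensor._version`:

```python
import torch

w = torch.nn.Parameter(torch.zeros(4, 4))
v0 = w._version
with torch.no_grad():
    w.copy_(torch.ones(4, 4))  # how RL weight broadcasts update params in place
# The version counter advanced, so the cached fp32 copy is rebuilt next forward.
print(w._version != v0)  # True
```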
Why

Under FP8/bf16 inference the only step between hidden states and the sampler that's not fp32 is the lm_head matmul itself — vLLM's sampler already promotes to fp32 before softmax/sampling. Casting the operands to fp32 doesn't add precision to the bf16 weight values themselves (zero-padded mantissa), but it does keep the matmul output in fp32 instead of truncating to bf16 — which is what the sampler actually needs. SGLang implements the same flag the same way (`logits_processor.py:894-897` in their tree).

Closes the rollout-vs-trainer logprob gap that destabilizes RL importance sampling. Background: MiniMax-M1 §7.6, ScaleRL, Meta arxiv:2510.13786. Tracks vllm-project/vllm#24567, restricted here to the untied-embeddings case.
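A toy, self-contained illustration (not from the PR) of the epilogue truncation described above: the bf16 product is rounded to bf16 on output, while the fp32 path keeps the extra mantissa bits.

```python
import torch

torch.manual_seed(0)
h = torch.randn(4, 4096, dtype=torch.bfloat16)    # stand-in for hidden_states
w = torch.randn(256, 4096, dtype=torch.bfloat16)  # toy lm_head.weight

out_bf16 = h @ w.t()                    # matmul output rounded to bf16
out_fp32 = h.float() @ w.float().t()    # identical operand values, fp32 throughout
# Nonzero: the only difference is precision lost rounding the output to bf16.
print((out_bf16.float() - out_fp32).abs().max())
```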
Validation
Two single-node SLURM jobs on Qwen3-8B-Base + hendrycks-math, identical apart from `enable_fp32_lm_head`. The fp32 run shows a measurably lower train/inference logprob mismatch on wandb. Confirmed the patched path actually executes (vs silently falling through) by temporarily injecting `1/0` inside the fp32 branch — the worker crashed with `ZeroDivisionError` on the first forward, proving the init-time capture is correct.

Usage
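The original usage snippet isn't preserved in this view. As a hedged sketch of the equivalent wiring, assuming vLLM's `LLM`/`EngineArgs` accept `additional_config` (in the launcher you would set `inference.enable_fp32_lm_head` instead; the model name is illustrative):

```python
# Hedged sketch, not the launcher's actual entry point: shows what the
# inference.enable_fp32_lm_head flag expands to on the vLLM side.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B-Base",
    dtype="bfloat16",
    additional_config={"fp32_lm_head": True},
)
```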
Notes
- Keeps one extra fp32 copy of `lm_head.weight` (~600MB for Qwen3-30B vocab/hidden), rebuilt only when the underlying weight is mutated. The fp32 GEMM at the end is negligible vs the rest of the model.
- Not supported for `tie_word_embeddings=True` — explicit `NotImplementedError`.
- Can be removed once vllm-project/vllm#24567 (or a successor) lands upstream.

🤖 Generated with Claude Code