feat(inference): fp32 lm_head via native bf16xbf16 -> fp32 mm (alt to #2438) #2441
Merged
Conversation
Alternative implementation of fp32 lm_head proposed by @Jackmin801 on PR #2438: use torch.mm(..., out_dtype=torch.float32) (PyTorch >= 2.10) so the matmul accumulates and emits fp32 directly, without zero-padding the bf16 operands or maintaining a separate fp32 weight copy.

The win is purely in the epilogue: F.linear(bf16, bf16) accumulates in fp32 internally on tensor cores but truncates the output to bf16 before returning. Native out_dtype=fp32 keeps the full mantissa for the downstream softmax/sample.

Removes the cached fp32 weight + tensor-version invalidation logic from PR #2438 — not needed when there's no separate fp32 copy.

Wired identically to PR #2438: inference.enable_fp32_lm_head config flag -> additional_config["fp32_lm_head"] -> captured on LogitsProcessor.__init__ -> per-instance flag read in _get_logits.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Summary
Alternative implementation of `inference.enable_fp32_lm_head` (sibling of #2438), following @Jackmin801's suggestion: use `torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32)` (PyTorch >= 2.10) instead of explicitly upcasting both operands.

Why this version
@Jackmin801 pointed out on #2438 that the operand-upcast approach (`hidden_states.float() @ lm_head.weight.float().T`) only differs from the bf16 path in the epilogue — bf16 weights cast to fp32 are just zero-padded in the mantissa, so the operands gain no precision; the actual win is keeping the fp32 GEMM output instead of truncating it to bf16. Native `out_dtype=fp32` does that directly, with bf16 operands and an fp32 accumulator/epilogue.
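To make the epilogue point concrete, here is a minimal sketch of the three paths (shapes are illustrative, and the `out_dtype` kwarg on `torch.mm` is assumed to behave as described above on PyTorch >= 2.10):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; real sizes come from the model's hidden dim and vocab.
hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(32_000, 4096, dtype=torch.bfloat16, device="cuda")  # lm_head.weight

# Default path: tensor cores accumulate in fp32 internally, but the epilogue
# truncates the result to bf16 before returning it.
logits_bf16 = F.linear(hidden, w)                             # bf16 output

# #2438's operand-upcast path: casting bf16 -> fp32 only zero-pads the mantissa,
# so the operands gain nothing; the benefit is entirely the fp32 output dtype.
logits_upcast = hidden.float() @ w.float().T                  # fp32 output

# This PR's path: same bf16 operands, fp32 accumulator emitted directly
# (assumes PyTorch >= 2.10 with the out_dtype kwarg on torch.mm).
logits_fp32 = torch.mm(hidden, w.T, out_dtype=torch.float32)  # fp32 output
```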
Trade-offs vs #2438 (operand upcast + cache):

- No cached fp32 weight copy to keep in sync: we read `lm_head.weight` in place, while #2438 (fp32 lm_head for stable logprobs under FP8/bf16) needs `Tensor._version` checks to avoid serving stale step-0 weights — see bugbot comment.
- Requires PyTorch >= 2.10 for the `torch.mm` `out_dtype` kwarg.
Same wiring as #2438: `inference.enable_fp32_lm_head` -> `additional_config["fp32_lm_head"]` -> captured once on `LogitsProcessor.__init__` -> per-instance flag read in `_get_logits`.

Usage
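As a hedged sketch of enabling the feature: the user-facing switch is the `inference.enable_fp32_lm_head` config flag, which reaches vLLM as `additional_config["fp32_lm_head"]`; on the vLLM side that is roughly equivalent to the following (the model name is a placeholder, and passing `additional_config` directly as an engine argument is an assumption here):

```python
# Hypothetical sketch of what the flag expands to on the vLLM side; the
# user-facing switch remains the repo's inference.enable_fp32_lm_head flag.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",    # placeholder model
    dtype="bfloat16",
    additional_config={"fp32_lm_head": True},    # what the flag threads through as
)
```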
Notes
- If upstream vLLM ever uses `out_dtype` in its own logits path, our patch becomes a single-line removal.
- The patch never mutates `lm_head.weight`, so tied configs are safe (just slower, because we still use `lm_head.weight` regardless of whether it aliases `embed_tokens`).

🤖 Generated with Claude Code
Note
Medium Risk
Touches vLLM logits computation via monkey-patching, which can affect numerical behavior and performance across all inference paths when enabled. Gated behind a config flag but depends on PyTorch support for `torch.mm(..., out_dtype=...)`.

Overview
Adds an opt-in `inference.enable_fp32_lm_head` flag that threads through to vLLM via `additional_config["fp32_lm_head"]`.

When enabled, installs a worker-side monkey patch on vLLM's `LogitsProcessor` to compute the `lm_head` matmul with `torch.mm(..., out_dtype=torch.float32)` (bf16×bf16 -> fp32), improving logprob stability under lower-precision inference while leaving the default path unchanged when the flag is off.

Reviewed by Cursor Bugbot for commit 7c9cfe4. Bugbot is set up for automated code reviews on this repo.
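For reference, a rough sketch of the shape of such a worker-side patch (not the PR's actual code: the `_get_logits` signature and the `_fp32_lm_head` attribute name are assumptions, and a real implementation also has to handle quantized lm_head layers and tensor-parallel gathering):

```python
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor

_orig_get_logits = LogitsProcessor._get_logits

def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
    # Fall back to the stock path unless the per-instance flag (captured from
    # additional_config["fp32_lm_head"] at __init__ time) is set.
    if not getattr(self, "_fp32_lm_head", False):
        return _orig_get_logits(self, hidden_states, lm_head, embedding_bias)
    # bf16 x bf16 matmul with an fp32 accumulator/epilogue (PyTorch >= 2.10).
    logits = torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32)
    if embedding_bias is not None:
        logits = logits + embedding_bias
    return logits

LogitsProcessor._get_logits = _patched_get_logits
```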