
feat(inference): fp32 lm_head via native bf16xbf16 -> fp32 mm (alt to #2438) #2441

Merged
samsja merged 1 commit into main from feat/fp32-lm-head-out-dtype on May 8, 2026

Conversation

@samsja (Member) commented May 8, 2026

Summary

Alternative implementation of inference.enable_fp32_lm_head (sibling of #2438), following @Jackmin801's suggestion: use torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32) (PyTorch >= 2.10) instead of explicitly upcasting both operands.

Why this version

@Jackmin801 pointed out on #2438 that the operand-upcast approach (hidden_states.float() @ lm_head.weight.float().T) only differs from the bf16 path in the epilogue: bf16 weights cast to fp32 just gain zero-padded mantissa bits, so the operands carry no extra precision. The actual win is keeping the fp32 GEMM output instead of truncating it to bf16, and native out_dtype=fp32 does that directly with bf16 operands and an fp32 accumulator/epilogue.
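
As a rough illustration of the two approaches (a minimal sketch with made-up shapes, not the patch itself):

    import torch

    hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
    lm_head_weight = torch.randn(151936, 4096, dtype=torch.bfloat16, device="cuda")

    # Operand-upcast path (#2438): materializes fp32 copies of both operands,
    # fp32 output.
    logits_upcast = hidden.float() @ lm_head_weight.float().T

    # Native out_dtype path (this PR, PyTorch >= 2.10): bf16 operands,
    # fp32 accumulation and fp32 output, no fp32 weight copy.
    logits_native = torch.mm(hidden, lm_head_weight.T, out_dtype=torch.float32)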

Trade-offs vs #2438 (operand upcast + cache):

  • ✅ No fp32 weight copy on the layer (~600MB saved for Qwen3-30B-class vocab/hidden)
  • ✅ No cache invalidation logic (RL weight broadcasts mutate lm_head.weight in place; #2438 needs Tensor._version checks to avoid serving stale step-0 weights — see bugbot comment)
  • ✅ Simpler code: 25-line patch body vs ~50
  • ⚠️ Requires PyTorch >= 2.10 for the torch.mm out_dtype kwarg (see the version-check sketch after this list)
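
One way to gate the new kwarg at runtime (a hedged sketch; the actual patch may check this differently):

    import torch
    from packaging.version import Version

    def mm_supports_out_dtype() -> bool:
        # torch.mm(..., out_dtype=...) is only expected from PyTorch 2.10 onward,
        # per the requirement above; callers fall back to the default bf16 path
        # when this returns False.
        return Version(torch.__version__.split("+")[0]) >= Version("2.10")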

Same wiring as #2438: inference.enable_fp32_lm_head -> additional_config["fp32_lm_head"] -> captured once on LogitsProcessor.__init__ -> per-instance flag read in _get_logits.
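
A minimal sketch of that wiring (vLLM internals are simplified; the exact _get_logits signature and the way additional_config reaches worker code are not reproduced here):

    import torch
    from vllm.model_executor.layers.logits_processor import LogitsProcessor

    # Assumed to be set by the patch installer from additional_config["fp32_lm_head"].
    _FP32_LM_HEAD = True

    _orig_init = LogitsProcessor.__init__
    _orig_get_logits = LogitsProcessor._get_logits

    def _patched_init(self, *args, **kwargs):
        _orig_init(self, *args, **kwargs)
        self._fp32_lm_head = _FP32_LM_HEAD  # captured once per instance

    def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
        if getattr(self, "_fp32_lm_head", False):
            # bf16 x bf16 -> fp32 via the native epilogue (PyTorch >= 2.10);
            # no fp32 copy of lm_head.weight is materialized.
            logits = torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32)
            if embedding_bias is not None:
                logits = logits + embedding_bias
            return logits
        return _orig_get_logits(self, hidden_states, lm_head, embedding_bias)

    LogitsProcessor.__init__ = _patched_init
    LogitsProcessor._get_logits = _patched_get_logits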

Usage

[inference]
enable_fp32_lm_head = true
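
If the engine is constructed directly, the same flag maps onto vLLM's additional_config; a hedged sketch (the model name is illustrative and engine arguments vary by vLLM version):

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",  # illustrative; any bf16 model works
        additional_config={"fp32_lm_head": True},
    )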

Notes

  • Tracks vllm-project/vllm#24567 — the upstream PR uses operand upcast; if it adopts native out_dtype, our patch becomes a single-line removal.
  • No tied-embeddings restriction needed here — we don't mutate lm_head.weight, so tied configs are safe (just slower because we still use lm_head.weight regardless of whether it aliases embed_tokens).

🤖 Generated with Claude Code


Note

Medium Risk
Touches vLLM logits computation via monkey-patching, which can affect numerical behavior and performance across all inference paths when enabled. Gated behind a config flag but depends on PyTorch support for torch.mm(..., out_dtype=...).

Overview
Adds an opt-in inference.enable_fp32_lm_head flag that threads through to vLLM via additional_config["fp32_lm_head"].

When enabled, installs a worker-side monkey patch on vLLM’s LogitsProcessor to compute the lm_head matmul with torch.mm(..., out_dtype=torch.float32) (bf16×bf16 -> fp32) to improve logprob stability under lower-precision inference, while leaving the default path unchanged when the flag is off.

Reviewed by Cursor Bugbot for commit 7c9cfe4. Bugbot is set up for automated code reviews on this repo.

Alternative implementation of fp32 lm_head proposed by @Jackmin801 on
PR #2438: use torch.mm(..., out_dtype=torch.float32) (PyTorch >= 2.10) so
the matmul accumulates and emits fp32 directly without zero-padding the
bf16 operands or maintaining a separate fp32 weight copy.

The win is purely in the epilogue: F.linear(bf16, bf16) accumulates in
fp32 internally on tensor cores but truncates the output to bf16 before
returning. Native out_dtype=fp32 keeps the full mantissa for the
downstream softmax/sample.

Removes the cached fp32 weight + tensor-version invalidation logic from
PR #2438 — not needed when there's no separate fp32 copy.

Wired identically to PR #2438: inference.enable_fp32_lm_head config flag
-> additional_config["fp32_lm_head"] -> captured on LogitsProcessor.__init__
-> per-instance flag read in _get_logits.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@samsja samsja marked this pull request as ready for review May 8, 2026 05:26
@Jackmin801 (Member) left a comment

nice! lgtm

@samsja samsja merged commit 5f2c004 into main May 8, 2026
14 checks passed