
feat(inference): fp32 lm_head for stable logprobs under FP8/bf16 #2438

Open

samsja wants to merge 1 commit into main from feat/fp32-lm-head

feat(inference): fp32 lm_head for stable logprobs under FP8/bf16#2438
samsja wants to merge 1 commit into
mainfrom
feat/fp32-lm-head

Conversation

@samsja samsja commented May 7, 2026

Summary

  • New inference.enable_fp32_lm_head flag that runs the lm_head projection in fp32 instead of the model dtype (bf16). Both hidden_states and lm_head.weight are cast to float32 before the matmul.
  • Wired through vLLM's additional_config dict (no env vars). The launcher's to_vllm() injects {"fp32_lm_head": True} when the flag is set. A sketch of this wiring follows the list.
  • The flag is captured once in LogitsProcessor.__init__ and stored on the instance; the per-forward check then just reads self._fp32_lm_head_enabled. This matters because get_current_vllm_config() only resolves during model init — calling it from _get_logits during serving raises AssertionError (we verified empirically that the per-call approach silently no-ops).
  • The fp32 weight is cached on the layer, but invalidated against Tensor._version so RL weight broadcasts (NCCL / filesystem param.copy_(...)) don't leave a stale step-0 fp32 copy serving forever.
  • Raises NotImplementedError when tie_word_embeddings=True, since casting lm_head.weight in place would also cast embed_tokens.
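
A minimal sketch of the launcher-side wiring described above. This is illustrative only: additional_config is the EngineArgs/VllmConfig field the PR relies on, but the to_vllm() signature and config attribute names here are assumptions, not the actual prime-rl API.

from vllm import EngineArgs

def to_vllm(config) -> EngineArgs:
    # Forward the toggle to workers via vLLM's additional_config, which is what
    # get_current_vllm_config() exposes to the patched LogitsProcessor.__init__.
    additional_config = {}
    if config.enable_fp32_lm_head:          # hypothetical config attribute
        additional_config["fp32_lm_head"] = True
    return EngineArgs(
        model=config.model_name,            # hypothetical config attribute
        additional_config=additional_config,
    )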

Why

Under FP8/bf16 inference the only step between hidden states and the sampler that's not fp32 is the lm_head matmul itself — vLLM's sampler already promotes to fp32 before softmax/sampling. Casting the operands to fp32 doesn't add precision to the bf16 weight values themselves (zero-padded mantissa), but it does keep the matmul output in fp32 instead of truncating to bf16 — which is what the sampler actually needs. SGLang implements the same flag the same way (logits_processor.py:894-897 in their tree).
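
A rough stand-alone illustration of the epilogue effect (synthetic shapes and random weights, not the real model; exact accumulation behavior differs between CPU and tensor cores, but the output-truncation point is the same):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(8, 1024, dtype=torch.bfloat16)         # stand-in hidden states
w = torch.randn(32_000, 1024, dtype=torch.bfloat16) * 0.02  # stand-in lm_head weight

logits_bf16 = F.linear(hidden, w)                  # status quo: output truncated to bf16
logits_fp32 = F.linear(hidden.float(), w.float())  # this PR: same weight values, fp32 output

# The sampler promotes to fp32 either way; in the bf16 path the mantissa is already gone.
gap = (torch.log_softmax(logits_bf16.float(), dim=-1)
       - torch.log_softmax(logits_fp32, dim=-1)).abs().max()
print(gap)  # nonzero purely from the bf16 epilogue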

Closes the rollout-vs-trainer logprob gap that destabilizes RL importance sampling. Background: MiniMax-M1 §7.6, ScaleRL, Meta arxiv:2510.13786. Tracks vllm-project/vllm#24567, restricted here to the untied-embeddings case.

Validation

Two single-node SLURM jobs on Qwen3-8B-Base + hendrycks-math, identical apart from enable_fp32_lm_head. The fp32 run shows a measurably lower train/inference logprob mismatch on wandb. Confirmed the patched path actually executes (vs silently falling through) by temporarily injecting 1/0 inside the fp32 branch — the worker crashed with ZeroDivisionError on the first forward, proving the init-time capture is correct.

Usage

[inference]
enable_fp32_lm_head = true

Notes

  • Cost: one extra fp32 copy of lm_head.weight (~600MB for Qwen3-30B vocab/hidden), rebuilt only when the underlying weight is mutated. The fp32 GEMM at the end is negligible vs the rest of the model.
  • Limitation: no support for tie_word_embeddings=True — explicit NotImplementedError.
  • Removable once vllm-project/vllm#24567 (or a successor) lands upstream.

🤖 Generated with Claude Code

@samsja samsja force-pushed the feat/fp32-lm-head branch 2 times, most recently from 2984e45 to 3efcdd8 on May 8, 2026 01:41
@samsja samsja marked this pull request as ready for review May 8, 2026 01:44
@samsja samsja force-pushed the feat/fp32-lm-head branch from 3efcdd8 to b09816b on May 8, 2026 01:45

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Reviewed by Cursor Bugbot for commit b09816b.

Two comment threads on src/prime_rl/inference/patches.py (marked outdated).
Adds inference.enable_fp32_lm_head to run the lm_head projection in fp32
instead of the model dtype. Closes the rollout-vs-trainer logprob gap that
shows up in low-precision RL setups where the head matmul is the only
remaining bf16 step before the sampler (which already promotes to fp32).
Mirrors SGLang's --enable-fp32-lm-head and tracks vllm-project/vllm#24567,
restricted here to the untied-embedding case.

Wired through vLLM's additional_config dict so workers read the flag from
get_current_vllm_config() — no env var. Lazily caches lm_head.weight.float()
on the layer so the cast happens once.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@samsja samsja force-pushed the feat/fp32-lm-head branch from b09816b to 0d79bdd on May 8, 2026 02:11
Comment on lines +1131 to +1166
def _patched_init(self, *args, **kwargs):
    _original_init(self, *args, **kwargs)
    vllm_config = get_current_vllm_config()
    # `or {}` defends against downstream code overriding additional_config to None
    # (default is an empty dict, but defensive copying is cheap).
    additional_config = vllm_config.additional_config or {}
    self._fp32_lm_head_enabled = additional_config.get("fp32_lm_head", False)
    if self._fp32_lm_head_enabled:
        if getattr(vllm_config.model_config.hf_config, "tie_word_embeddings", False):
            raise NotImplementedError(
                "fp32_lm_head is not supported for models with tie_word_embeddings=True; "
                "casting lm_head.weight in place would also cast embed_tokens. Use a model "
                "with untied embeddings, or extend this patch to keep a separate fp32 copy."
            )
        logger.warning("fp32 lm_head ENABLED for this LogitsProcessor instance.")


def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
    if not getattr(self, "_fp32_lm_head_enabled", False):
        return _original_get_logits(self, hidden_states, lm_head, embedding_bias)

    # Cache the fp32 weight, but invalidate it whenever lm_head.weight has
    # been mutated. RL weight broadcasts (NCCL / filesystem) update params
    # in place via param.copy_(...), which bumps Tensor._version. Detect
    # that and rebuild — otherwise we'd serve stale step-0 weights forever.
    weight = lm_head.weight
    cur_version = weight._version
    cached = getattr(lm_head, "_fp32_lm_head_weight", None)
    cached_version = getattr(lm_head, "_fp32_lm_head_weight_version", None)
    if cached is None or cached_version != cur_version:
        with torch.no_grad():
            cached = weight.detach().to(torch.float32)
        lm_head._fp32_lm_head_weight = cached
        lm_head._fp32_lm_head_weight_version = cur_version

    bias_fp32 = embedding_bias.to(torch.float32) if embedding_bias is not None else None
    logits = F.linear(hidden_states.to(torch.float32), cached, bias_fp32)
@Jackmin801 (Member) left a comment:

I think the only thing we gain here is removing the epilogue bf16 downcast. We should be able to replace this with torch.mm(hidden_states, lm_head.weight.T, out_dtype=torch.float32) and do this more efficiently.

samsja added a commit that referenced this pull request May 8, 2026
Alternative implementation of fp32 lm_head proposed by @Jackmin801 on
PR #2438: use torch.mm(..., out_dtype=torch.float32) (PyTorch >= 2.10) so
the matmul accumulates and emits fp32 directly without zero-padding the
bf16 operands or maintaining a separate fp32 weight copy.

The win is purely in the epilogue: F.linear(bf16, bf16) accumulates in
fp32 internally on tensor cores but truncates the output to bf16 before
returning. Native out_dtype=fp32 keeps the full mantissa for the
downstream softmax/sample.

Removes the cached fp32 weight + tensor-version invalidation logic from
PR #2438 — not needed when there's no separate fp32 copy.

Wired identically to PR #2438: inference.enable_fp32_lm_head config flag
-> additional_config["fp32_lm_head"] -> captured on LogitsProcessor.__init__
-> per-instance flag read in _get_logits.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
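
For reference, a minimal sketch of what the alternative _get_logits body could look like. It assumes torch.mm accepts out_dtype as stated in the commit message (PyTorch >= 2.10) and elides the tensor-parallel gather / vocab-pruning steps that follow in vLLM's original method; it is not the merged implementation.

def _patched_get_logits(self, hidden_states, lm_head, embedding_bias):
    if not getattr(self, "_fp32_lm_head_enabled", False):
        return _original_get_logits(self, hidden_states, lm_head, embedding_bias)
    # bf16 x bf16 matmul that accumulates and emits fp32 directly, so there is no
    # separate fp32 weight copy and no cache/version-invalidation bookkeeping.
    logits = torch.mm(hidden_states, lm_head.weight.t(), out_dtype=torch.float32)
    if embedding_bias is not None:
        logits = logits + embedding_bias.to(torch.float32)
    # ... remaining gather / pruning steps as in the original path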
