
Fp8 layerwise #2415

Open

S1ro1 wants to merge 2 commits into main from fp8-layerwise

Conversation

@S1ro1 (Collaborator) commented on May 4, 2026

Note

Medium Risk
Changes the NCCL weight-broadcast/update path across trainer, orchestrator, and vLLM workers, including new FP8 quantization logic and monkeypatching vLLM reload internals; mistakes could cause incorrect weights, OOMs, or failed updates during training.

Overview
Enables FP8 “layerwise” weight updates when the inference model differs from the trainer, by auto-detecting a quant_scheme (fp8_blockwise/fp8_channelwise) from the inference model config and propagating it through shared weight_broadcast configs and the /init_broadcaster RPC.

On the trainer side, adds a new FP8 dispatch module (fp8_dispatch.py) and channelwise quantization, then uses these to optionally quantize each layer into HF-checkpoint format before the NCCL broadcast. On the inference side, the NCCL worker can now switch between kernel-quantized loads, checkpoint-format loads, and the vLLM layerwise reload flow; vLLM monkeypatches are added to suppress layerwise-reload warning spam and to skip problematic tensors that would otherwise cause meta-materialization OOMs (pending vLLM >= 0.20).
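
For context on the two schemes named above, here is a minimal sketch of how blockwise and channelwise FP8 scales differ in granularity. The helper names and the 128x128 block size are illustrative assumptions for this review, not the PR's actual code:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def channelwise_scales(weight: torch.Tensor) -> torch.Tensor:
    # One scale per output row: shape (rows,)
    return (weight.float().abs().amax(dim=1) / FP8_MAX).clamp(min=1e-12)

def blockwise_scales(weight: torch.Tensor, block: int = 128) -> torch.Tensor:
    # One scale per (block x block) tile: shape (rows // block, cols // block),
    # assuming both dimensions divide evenly by the block size.
    rows, cols = weight.shape
    tiles = weight.float().abs().reshape(rows // block, block, cols // block, block)
    return (tiles.amax(dim=(1, 3)) / FP8_MAX).clamp(min=1e-12)

Whichever scheme is detected, the quantized weights and their scales are what get broadcast over NCCL in place of the full-precision trainer weights.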

Reviewed by Cursor Bugbot for commit 6e9df4e. Bugbot is set up for automated code reviews on this repo.


@cursor (bot) left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.


row_max = weight.float().abs().amax(dim=1, keepdim=True)
scales = (row_max / fp8_max).clamp(min=1e-12)
quantized = (weight.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
return quantized.contiguous(), scales.float().contiguous()

Channelwise FP8 scales have the wrong shape

High Severity

quantize_to_fp8_channelwise returns scales with shape (rows, 1) due to keepdim=True in the amax call. FP8 channelwise (per-channel) scale tensors in HF checkpoints and vLLM's ChannelQuantScaleParameter are expected to be 1D (rows,). When these 2D scales are sent over NCCL and loaded via model.load_weights(), vLLM's weight loaders will fail on a shape mismatch during copy_() into the model's 1D scale parameters. The keepdim=True is needed for the broadcasting division, but the scales need to be squeezed before returning.
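
A minimal sketch of the suggested fix, assuming the function shown in the diff above and the 1D scale layout described here; fp8_max is taken from torch.finfo to keep the sketch self-contained:

import torch

def quantize_to_fp8_channelwise(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # keepdim=True keeps the (rows, 1) shape needed for the broadcasting division below.
    row_max = weight.float().abs().amax(dim=1, keepdim=True)
    scales = (row_max / fp8_max).clamp(min=1e-12)
    quantized = (weight.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Squeeze the scales to 1D (rows,) before returning so they match the model's
    # 1D scale parameters when loaded via model.load_weights().
    return quantized.contiguous(), scales.squeeze(-1).float().contiguous()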



return FP8BlockwiseQuantScheme()
if name == "fp8_channelwise":
return FP8ChannelwiseQuantScheme()
raise ValueError(f"Unknown quant scheme: {name}")

Unused get_quant_scheme function is dead code

Low Severity

get_quant_scheme is defined but never called anywhere in the codebase. A grep confirms the only match is its own definition. The config stores the scheme name string but uses detect_quant_scheme (which returns the full scheme object) at both call sites. This function is dead code that adds maintenance burden.


