Fp8 layerwise #2415
Conversation
Cursor Bugbot has reviewed your changes and found 2 potential issues.
Reviewed by Cursor Bugbot for commit 6e9df4e.
```python
    row_max = weight.float().abs().amax(dim=1, keepdim=True)
    scales = (row_max / fp8_max).clamp(min=1e-12)
    quantized = (weight.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return quantized.contiguous(), scales.float().contiguous()
```
Channelwise FP8 scales have wrong shape dimension
High Severity
`quantize_to_fp8_channelwise` returns scales with shape `(rows, 1)` due to `keepdim=True` in the `amax` call. FP8 channelwise (per-channel) scale tensors in HF checkpoints and vLLM's `ChannelQuantScaleParameter` are expected to be 1D `(rows,)`. When these 2D scales are sent over NCCL and loaded via `model.load_weights()`, vLLM's weight loaders will fail on a shape mismatch during `copy_()` into the model's 1D scale parameters. The `keepdim=True` is needed for the broadcasting division, but the scales need to be squeezed before returning.
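A minimal sketch of the suggested fix: keep `keepdim=True` for the broadcasting division, then squeeze the scales to 1D before returning. Deriving `fp8_max` from the e4m3 dtype is an assumption here; the PR may define it elsewhere.

```python
import torch

def quantize_to_fp8_channelwise(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumption: fp8_max comes from the e4m3 dtype; not taken from the PR.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    row_max = weight.float().abs().amax(dim=1, keepdim=True)  # (rows, 1)
    scales = (row_max / fp8_max).clamp(min=1e-12)             # keepdim needed to broadcast below
    quantized = (weight.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Squeeze to 1D (rows,) so vLLM's ChannelQuantScaleParameter accepts the copy_().
    return quantized.contiguous(), scales.squeeze(-1).float().contiguous()
```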
```python
        return FP8BlockwiseQuantScheme()
    if name == "fp8_channelwise":
        return FP8ChannelwiseQuantScheme()
    raise ValueError(f"Unknown quant scheme: {name}")
```
Unused get_quant_scheme function is dead code
Low Severity
`get_quant_scheme` is defined but never called anywhere in the codebase; a grep confirms the only match is its own definition. The config stores the scheme name string but uses `detect_quant_scheme` (which returns the full scheme object) at both call sites. This function is dead code that adds maintenance burden.
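For context, a self-contained sketch of the situation described above; only the function and class names come from the diff, the scheme classes are stubbed and the call-site shape is an assumption:

```python
class FP8BlockwiseQuantScheme: ...      # stub for illustration
class FP8ChannelwiseQuantScheme: ...    # stub for illustration

# Dead code: defined but never called; a grep finds only this definition.
def get_quant_scheme(name: str):
    if name == "fp8_blockwise":
        return FP8BlockwiseQuantScheme()
    if name == "fp8_channelwise":
        return FP8ChannelwiseQuantScheme()
    raise ValueError(f"Unknown quant scheme: {name}")

def detect_quant_scheme(model_config):
    # Hypothetical: both real call sites resolve the scheme object directly
    # from the inference model config, bypassing the name-based factory above.
    ...
```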


Note
Medium Risk
Changes the NCCL weight-broadcast/update path across trainer, orchestrator, and vLLM workers, including new FP8 quantization logic and monkeypatching vLLM reload internals; mistakes could cause incorrect weights, OOMs, or failed updates during training.
Overview
Enables FP8 "layerwise" weight updates when the inference model differs from the trainer, by auto-detecting a `quant_scheme` (`fp8_blockwise`/`fp8_channelwise`) from the inference model config and propagating it through shared `weight_broadcast` configs and the `/init_broadcaster` RPC.

On the trainer side, adds a new FP8 dispatch module (`fp8_dispatch.py`) and channelwise quantization, then uses these to optionally quantize each layer into HF-checkpoint format before NCCL broadcast. On the inference side, the NCCL worker can now switch between kernel-quantized loads, checkpoint loads, or the vLLM layerwise reload flow, and adds vLLM monkeypatches to suppress layerwise-reload warning spam and skip problematic tensors to avoid meta-materialization OOMs (pending vLLM >= 0.20).
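To make the broadcast path concrete, here is a rough sketch, not the PR's actual `fp8_dispatch.py`, of the per-layer flow the overview describes; the `scheme.quantize` interface, the `.weight_scale` naming, and the uint8 view for NCCL are all assumptions:

```python
import torch
import torch.distributed as dist

def broadcast_layer(name: str, weight: torch.Tensor, scheme, group) -> None:
    """Optionally quantize one layer into HF-checkpoint tensors, then NCCL-broadcast."""
    if scheme is not None:
        qweight, scales = scheme.quantize(weight)  # e.g. FP8 channelwise: fp8 weight + 1D scales
        # HF FP8 checkpoints store a companion scale tensor next to each weight;
        # the ".weight_scale" suffix here is an assumption, not taken from the PR.
        tensors = {name: qweight, name.replace(".weight", ".weight_scale"): scales}
    else:
        tensors = {name: weight}
    for tensor in tensors.values():
        # NCCL has no float8 dtype, so fp8 payloads are commonly viewed as uint8.
        payload = tensor.view(torch.uint8) if tensor.dtype == torch.float8_e4m3fn else tensor
        dist.broadcast(payload, src=0, group=group)
```

On the receiving side, the inference worker would feed the named tensors into `model.load_weights()` (the checkpoint-load path) or hand them to the vLLM layerwise reload flow, depending on the configured mode.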