8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -1,5 +1,13 @@
Model Optimizer Changelog (Linux)
=================================
0.41 (2025-12-xx)
^^^^^^^^^^^^^^^^^

**Deprecations**

**New Features**
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.


0.40 (2025-12-xx)
^^^^^^^^^^^^^^^^^
17 changes: 1 addition & 16 deletions examples/llm_ptq/example_utils.py
@@ -20,7 +20,6 @@
import sys
import warnings
from pathlib import Path
from typing import Any

import torch
import transformers
@@ -159,7 +158,7 @@ def build_quant_cfg(

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = apply_kv_cache_quant(
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
)
@@ -403,20 +402,6 @@ def is_enc_dec(model_type) -> bool:
return model_type in ["t5", "bart", "whisper"]


def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any]):
"""Apply quantization to the kv cache of the model."""
# Update KV cache related bmm quantizers
# If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg", {"default": {"enable": False}})
quant_cfg["quant_cfg"].update(kv_cache_quant_cfg)

# Set default algorithm for kv cache quantization if not provided.
if not quant_cfg.get("algorithm"):
quant_cfg["algorithm"] = "max"

return quant_cfg


def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str:
"""Resolve a model name or path to a local directory path.

5 changes: 3 additions & 2 deletions examples/llm_ptq/hf_ptq.py
@@ -23,7 +23,6 @@
import torch
from accelerate.hooks import remove_hook_from_module
from example_utils import (
apply_kv_cache_quant,
build_quant_cfg,
copy_custom_model_files,
get_model,
@@ -86,8 +85,10 @@
KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8": "FP8_KV_CFG",
"fp8_affine": "FP8_AFFINE_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

mto.enable_huggingface_checkpointing()
@@ -257,7 +258,7 @@ def main(args):
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.kv_cache_qformat != "none":
quant_cfg = apply_kv_cache_quant(
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)

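As a side note, each choice above is just the name of a config object on mtq; hf_ptq.py resolves it with getattr, roughly as in this sketch (the "nvfp4_affine" value is only an illustrative pick):

import modelopt.torch.quantization as mtq

# Same mapping as in hf_ptq.py: CLI choice -> name of the KV cache config on mtq.
KV_QUANT_CFG_CHOICES = {
    "none": "none",
    "fp8": "FP8_KV_CFG",
    "fp8_affine": "FP8_AFFINE_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
    "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
    "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

kv_cache_qformat = "nvfp4_affine"  # illustrative; normally parsed from --kv_cache_qformat
if kv_cache_qformat != "none":
    kv_cache_quant_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"]
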
1 change: 1 addition & 0 deletions modelopt/torch/quantization/__init__.py
@@ -24,3 +24,4 @@
from .conversion import *
from .model_quant import *
from .nn.modules.quant_module import QuantModuleRegistry
from .utils import update_quant_cfg_with_kv_cache_quant
165 changes: 165 additions & 0 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -37,6 +37,7 @@
)
from modelopt.torch.utils.distributed import ParallelState

from ..model_calib import max_calibrate
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
@@ -45,6 +46,7 @@
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEDotProductAttention,
TERowParallelGroupedLinear,
)

@@ -590,6 +592,169 @@ def _setup(self):
self.linear_fc1.parallel_state = self.parallel_state
self.linear_fc2.parallel_state = self.parallel_state

@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
class _QuantTEDotProductAttention(QuantModule):
"""Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.

This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention
module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer,
k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after
RoPE has been applied.
"""

def _setup(self):
"""Initialize quantizers for Q, K, V tensors."""
self.q_bmm_quantizer = TensorQuantizer()
self.k_bmm_quantizer = TensorQuantizer()
self.v_bmm_quantizer = TensorQuantizer()

def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device and dtype from the parent module's parameters
param = next(iter(self.parameters()), None)
device = param.device if param is not None else torch.device("cuda")
dtype = param.dtype if param is not None else torch.float16

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
)
else:
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
)

# Calibrate each quantizer
quantizers = [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]

for _, quantizer in quantizers:
if quantizer is not None and quantizer.is_enabled():
if not hasattr(quantizer, "_amax") or quantizer._amax is None:
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
"""
# Quantize Q, K, V
query = self.q_bmm_quantizer(query)
key = self.k_bmm_quantizer(key)
value = self.v_bmm_quantizer(value)

return super().forward(query, key, value, *args, **kwargs)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
sharded_state_dict[prefix + k] = v

# Process _amax in bmm_quantizers
for name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if hasattr(quantizer, "_amax") and quantizer._amax is not None:
amax_key = f"{prefix}{name}._amax"
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

if quantizer_state_dict:
sharded_state_dict.update(
**make_sharded_tensors_for_checkpoint(
quantizer_state_dict, prefix, {}, sharded_offsets
)
)

return sharded_state_dict

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Handle loading state dict for quantizers."""
for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
full_prefix = f"{prefix}{quantizer_name}."
amax_key = f"{prefix}{quantizer_name}._amax"

# If amax is in state_dict, rename it to the format expected by TensorQuantizer
if amax_key in state_dict:
expected_amax_key = f"{full_prefix}_amax"
state_dict[expected_amax_key] = state_dict.pop(amax_key)

# Handle other quantizer states
for k in list(state_dict.keys()):
if "_quantizer" in k and "_amax" not in k:
name = k.split(prefix)[-1] if prefix else k
if name in self.state_dict():
state_dict[k] = state_dict[k].view_as(self.state_dict()[name])

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

def modelopt_post_restore(self, name=""):
"""Restore quantizer states after model loading."""
super().modelopt_post_restore(name)

def _check_unsupported_states(quantizer):
"""Check for unsupported quantizer states and warn if found."""
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "
f"incorrect. Please implement a custom restore for {k}."
)

calibration_needed = False

for quantizer_name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
continue

_check_unsupported_states(quantizer)

if not hasattr(quantizer, "_amax") or quantizer._amax is None:
calibration_needed = True

if calibration_needed:
self._calibrate_quantizers()


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
class _QuantMoELayer(QuantModule):
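With the TEDotProductAttention registration above in place, the intended end-to-end flow on a Megatron-Core model is roughly the sketch below; the model and calib_dataloader objects are assumed to exist, and the forward loop is only illustrative:

import copy

import modelopt.torch.quantization as mtq

# Combine an NVFP4 weight/activation config with NVFP4 KV cache quantization.
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
    copy.deepcopy(mtq.NVFP4_DEFAULT_CFG), mtq.NVFP4_KV_CFG["quant_cfg"]
)

def forward_loop(model):
    # Hypothetical calibration loop; replace with the model's real calibration routine.
    for batch in calib_dataloader:
        model(**batch)

# quantize() swaps TEDotProductAttention for the registered quantized class, so the
# q/k/v bmm quantizers observe the post-RoPE tensors during calibration.
model = mtq.quantize(model, quant_cfg, forward_loop)
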
18 changes: 17 additions & 1 deletion modelopt/torch/quantization/utils.py
@@ -19,7 +19,7 @@

from collections import namedtuple
from contextlib import ExitStack, contextmanager, nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
@@ -43,6 +43,7 @@
"is_quantized_row_parallel_linear",
"reduce_amax",
"replace_function",
"update_quant_cfg_with_kv_cache_quant",
"weight_attr_names",
]

@@ -703,3 +704,18 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
if reshard:
with enable_fake_quant(root_module):
root_module.reshard()


def update_quant_cfg_with_kv_cache_quant(
quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any]
) -> dict[str, Any]:
"""Update the quant_cfg with the kv cache quant_cfg."""
# If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg", {"default": {"enable": False}})
quant_cfg["quant_cfg"].update(kv_cache_quant_cfg)

# Set default algorithm for kv cache quantization if not provided.
if not quant_cfg.get("algorithm"):
quant_cfg["algorithm"] = "max"
print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
return quant_cfg
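
For reference, a minimal sketch of how the new helper behaves on the KV-cache-only path, per the code above (the empty base config is an assumed illustration; real callers pass the config chosen from QUANT_CFG_CHOICES):

import modelopt.torch.quantization as mtq

# KV-cache-only case: no base weight/activation rules, just the FP8 KV cache quantizers.
cfg = mtq.update_quant_cfg_with_kv_cache_quant({}, mtq.FP8_KV_CFG["quant_cfg"])

# A disabled "default" rule is filled in for everything outside the KV cache, the KV
# cache quantizer rules are merged on top, and calibration falls back to "max" since
# the empty base config carried no algorithm.
assert cfg["algorithm"] == "max"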