
Commit aeff9c8

add kv cache quantization for mcore using bmm_quantizers
Signed-off-by: Kai Xu <[email protected]>
1 parent 598b9ce commit aeff9c8

File tree: 3 files changed, +326 −1 lines


modelopt/torch/quantization/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -71,3 +71,6 @@
 
 with import_plugin("trl"):
     from .trl import *
+
+with import_plugin("mcore"):
+    from .mcore import *
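The new block mirrors the existing `trl` entry: the mcore plugin is only registered when its optional dependencies (Megatron-Core and Transformer Engine) can be imported. For context, a minimal sketch of such an optional-import guard is shown below; the decorator form and warning text are illustrative assumptions, not the actual `import_plugin` implementation in modelopt.

# Illustrative sketch of an optional-plugin guard (assumed semantics, not modelopt's exact code).
import warnings
from contextlib import contextmanager


@contextmanager
def import_plugin(plugin_name):
    """Swallow ImportError so a missing optional dependency does not break package import."""
    try:
        yield
    except ImportError:
        warnings.warn(f"Optional plugin '{plugin_name}' is disabled: its dependency is not installed.")

With this pattern, `from .mcore import *` is simply skipped on installations without Megatron-Core, and the TEDotProductAttention registration in the next file never runs.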

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 149 additions & 1 deletion
@@ -26,17 +26,19 @@
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
+from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
+from ..model_calib import max_calibrate
 
 __all__ = []
 
@@ -460,3 +462,149 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
+class _QuantTEDotProductAttention(QuantModule):
+    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
+
+    def _setup(self):
+        """Initialize quantizers for Q, K, V tensors."""
+        self.q_bmm_quantizer = TensorQuantizer()
+        self.k_bmm_quantizer = TensorQuantizer()
+        self.v_bmm_quantizer = TensorQuantizer()
+
+    def _calibrate_quantizers(self):
+        """Calibrate quantizers with minimal dummy tensors."""
+        # Get device from the module parameters; fall back to CUDA if the module has none.
+        device = next((p.device for p in self.parameters()), torch.device('cuda'))
+
+        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
+        batch_size = 1
+        seq_len = 1
+
+        # Get dimensions from config
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
+
+        # Determine tensor format (default to sbhd if not specified)
+        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
+
+        if qkv_format == "sbhd":
+            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+        else:
+            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+
+        # Calibrate each quantizer
+        quantizers = [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]
+
+        for _, quantizer in quantizers:
+            if quantizer is not None and quantizer.is_enabled:
+                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                    quantizer.reset_amax()
+                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
+
+    def forward(self, query, key, value, *args, **kwargs):
+        """Apply post-RoPE quantization to the KV cache.
+
+        TEDotProductAttention receives Q, K, V after RoPE is applied,
+        so we quantize them directly for KV cache quantization.
+        """
+        # Quantize Q, K, V
+        query = self.q_bmm_quantizer(query)
+        key = self.k_bmm_quantizer(key)
+        value = self.v_bmm_quantizer(value)
+
+        return super().forward(query, key, value, *args, **kwargs)
+
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Create a sharded state dictionary for distributed checkpointing."""
+        sharded_state_dict = {}
+
+        # First add non-quantizer parameters
+        for k, v in self.state_dict(prefix="", keep_vars=True).items():
+            if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
+                sharded_state_dict[prefix + k] = v
+
+        # Process _amax in bmm_quantizers
+        for name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
+                amax_key = f"{prefix}{name}._amax"
+                sharded_state_dict[amax_key] = quantizer._amax
+
+        # Process other quantizer parameters in bmm_quantizers
+        quantizer_state_dict = {}
+        for k, v in self.state_dict(prefix="", keep_vars=True).items():
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
+                quantizer_state_dict[k] = v
+
+        if quantizer_state_dict:
+            sharded_state_dict.update(
+                **make_sharded_tensors_for_checkpoint(
+                    quantizer_state_dict, prefix, {}, sharded_offsets
+                )
+            )
+
+        return sharded_state_dict
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        """Handle loading state dict for quantizers."""
+        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
+            full_prefix = f"{prefix}{quantizer_name}."
+            amax_key = f"{prefix}{quantizer_name}._amax"
+
+            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
+            if amax_key in state_dict:
+                expected_amax_key = f"{full_prefix}_amax"
+                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+
+        # Handle other quantizer states
+        for k in list(state_dict.keys()):
+            if "_quantizer" in k and "_amax" not in k:
+                name = k.split(prefix)[-1] if prefix else k
+                if name in self.state_dict():
+                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
+
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def modelopt_post_restore(self, name=""):
+        """Restore quantizer states after model loading."""
+        super().modelopt_post_restore(name)
+
+        def _check_unsupported_states(quantizer):
+            if not hasattr(quantizer, "state_dict"):
+                return
+
+            for k in quantizer.state_dict().keys():
+                if k not in ["_amax", "_pre_quant_scale"]:
+                    warnings.warn(
+                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
+                        f"incorrect. Please implement a custom restore for {k}."
+                    )
+
+        calibration_needed = False
+
+        for quantizer_name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+                continue
+
+            _check_unsupported_states(quantizer)
+
+            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                calibration_needed = True
+
+        if calibration_needed:
+            self._calibrate_quantizers()
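For context, a condensed usage sketch of the new module, adapted from the `test_kv_cache_quant` test added below. `get_mcore_gpt_model` and `megatron_prefill` are helpers from the repo's `_test_utils` test package, and `mtq.FP8_KV_CFG` comes from `modelopt.torch.quantization`; treat this as a sketch of the intended workflow, not a supported API example.

# Sketch: KV cache quantization of an mcore GPT model (adapted from the new test).
import torch
import modelopt.torch.quantization as mtq
# get_mcore_gpt_model and megatron_prefill are test helpers from the repo's _test_utils package.

model = get_mcore_gpt_model(
    tensor_model_parallel_size=1,
    num_layers=1,
    hidden_size=64,
    num_attention_heads=4,
    vocab_size=32,
    transformer_impl="modelopt",  # "local" would not instantiate TEDotProductAttention
).cuda()

prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

def forward_fn(model):
    return megatron_prefill(model, prompt_tokens)

# With a KV cache config, each TEDotProductAttention is replaced by _QuantTEDotProductAttention
# and its q/k/v_bmm_quantizer attributes are calibrated during this call.
model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)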

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 174 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -232,6 +233,8 @@ def forward_fn(model):
         mtq.W4A8_AWQ_BETA_CFG,
         mtq.NVFP4_DEFAULT_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
+        # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant
+        # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local"
     ],
 )
 @pytest.mark.parametrize("compress", [False, True])
@@ -367,3 +370,174 @@ def forward_fn(model):
 def test_fp8_real_quantize():
     size = torch.cuda.device_count()
     spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl")
+
+
+def _test_kv_cache_quant_helper(config, rank, size):
+    """Helper function for testing KV cache quantization with TEDotProductAttention."""
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+
+    # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
+    # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
+    model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=1,
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=32,
+        transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
+    ).cuda()
+
+    # Create dummy input for calibration
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    def forward_fn(model):
+        return megatron_prefill(model, prompt_tokens)
+
+    # Test KV cache quantization with the given config
+    quantized_model = mtq.quantize(model, config, forward_fn)
+
+    # Find TEDotProductAttention modules and verify they have KV cache quantizers
+    te_attention_found = False
+    for name, module in quantized_model.named_modules():
+        # Check if this is a quantized TEDotProductAttention
+        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+            te_attention_found = True
+            # Verify all expected quantizers exist
+            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
+
+            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
+            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+
+    assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
+
+    # Quick smoke test that forward still works
+    output = forward_fn(quantized_model)
+    assert output is not None, "Forward pass failed"
+
+
+def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    """Helper for testing KV cache quantization with sharded state dict save/load."""
+    # Disable output_layer quantization (same as other sharded state dict tests)
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+
+    # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
+    model_ref = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=2,  # At least 2 layers to test multiple attention modules
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=64,
+        transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
+    ).cuda()
+
+    model_test = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=2,
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=64,
+        transformer_impl="modelopt",
+    ).cuda()
+
+    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
+
+    def forward_fn(model):
+        return megatron_prefill(model, prompt_tokens)
+
+    # Quantize the reference model
+    model_ref = mtq.quantize(model_ref, config, forward_fn)
+
+    # CRITICAL: model_test must also be quantized with the same config
+    # Otherwise it won't have the KV cache quantizer keys when loading state dict
+    model_test = mtq.quantize(model_test, config, forward_fn)
+
+    # Verify KV cache quantizers were created
+    kv_quantizers_found = False
+    for name, module in model_ref.named_modules():
+        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+            kv_quantizers_found = True
+            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+
+    assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
+
+    # Test sharded state dict save/load
+    sharded_state_dict_test_helper(
+        tmp_path,
+        model_ref,
+        model_test,
+        forward_fn,
+        meta_device=False,
+        version=None,
+    )
+
+    # Verify KV cache quantizers are restored correctly in model_test
+    for (name_ref, module_ref), (name_test, module_test) in zip(
+        model_ref.named_modules(), model_test.named_modules()
+    ):
+        if hasattr(module_ref, 'k_bmm_quantizer'):
+            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
+            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
+
+            # Check that quantizer states match
+            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
+                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+                if module_ref.k_bmm_quantizer._amax is not None:
+                    assert torch.allclose(
+                        module_ref.k_bmm_quantizer._amax,
+                        module_test.k_bmm_quantizer._amax
+                    ), f"K quantizer _amax mismatch in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
+                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+                if module_ref.v_bmm_quantizer._amax is not None:
+                    assert torch.allclose(
+                        module_ref.v_bmm_quantizer._amax,
+                        module_test.v_bmm_quantizer._amax
+                    ), f"V quantizer _amax mismatch in {name_test}"
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.FP8_KV_CFG,
+        mtq.NVFP4_KV_CFG,
+    ],
+)
+def test_kv_cache_quant(config):
+    """Verify KV cache quantization works correctly with TEDotProductAttention.
+
+    This test ensures TEDotProductAttention is properly registered and gets the
+    expected q/k/v_bmm_quantizers when using KV cache configs.
+
+    Note: This test requires Transformer Engine to be installed since TEDotProductAttention
+    is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
+    """
+    spawn_multiprocess_job(
+        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.FP8_KV_CFG,
+        mtq.NVFP4_KV_CFG,
+    ],
+)
+def test_kv_cache_sharded_state_dict(tmp_path, config):
+    """Test KV cache quantization with sharded state dict save/load.
+
+    This test verifies the complete workflow of saving and loading KV cache quantized
+    models with distributed checkpointing, ensuring quantizer states are properly
+    preserved across the save/load cycle.
+    """
+    size = min(2, torch.cuda.device_count())  # Use 2 GPUs if available, else 1
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
+        backend="nccl"
+    )
