Commit 5fc56fe

add kv cache quantization for mcore using bmm_quantizers
Signed-off-by: Kai Xu <[email protected]>
1 parent aeff9c8 commit 5fc56fe

File tree: 3 files changed (+98, -71 lines)

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -73,4 +73,4 @@
 from .trl import *
 
 with import_plugin("mcore"):
-    from .mcore import *
+    from .mcore import *
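For context, the `with import_plugin("mcore"):` guard above is what lets this optional plugin be skipped when megatron-core is not installed. The sketch below is an illustrative stand-in for such a guard, not modelopt's actual `import_plugin` implementation:

import warnings
from contextlib import contextmanager

@contextmanager
def import_plugin_sketch(plugin_name):
    """Illustrative stand-in: skip a plugin whose optional dependency is missing."""
    try:
        yield
    except ImportError:
        warnings.warn(f"Optional plugin '{plugin_name}' skipped: dependency not installed.")

# Hypothetical usage mirroring the guarded import above:
with import_plugin_sketch("mcore"):
    import megatron.core  # noqa: F401  # only succeeds when megatron-core is installed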

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 39 additions & 21 deletions
@@ -22,23 +22,23 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
 
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
 
+from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
-from ..model_calib import max_calibrate
 
 __all__ = []

@@ -466,7 +466,13 @@ def forward(self, input, *args, **kwargs):
 
 @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
 class _QuantTEDotProductAttention(QuantModule):
-    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
+    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.
+
+    This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention
+    module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer,
+    k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after
+    RoPE has been applied.
+    """
 
     def _setup(self):
         """Initialize quantizers for Q, K, V tensors."""
@@ -476,25 +482,35 @@ def _setup(self):
 
     def _calibrate_quantizers(self):
         """Calibrate quantizers with minimal dummy tensors."""
-        # Get device from parent module parameters
-        device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
-
+        # Get device and dtype from the parent module's parameters
+        param = next(iter(self.parameters()), None)
+        device = param.device if param is not None else torch.device("cuda")
+        dtype = param.dtype if param is not None else torch.float16
+
         # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
         batch_size = 1
         seq_len = 1
-
+
         # Get dimensions from config
         num_heads = self.config.num_attention_heads
-        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
-
+        head_dim = (
+            self.config.kv_channels
+            if hasattr(self.config, "kv_channels")
+            else self.config.hidden_size // num_heads
+        )
+
         # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
         qkv_format = "bshd" if apply_rope_fusion else "sbhd"
 
         if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
+            )
         else:
-            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
+            )
 
         # Calibrate each quantizer
         quantizers = [
@@ -504,14 +520,14 @@ def _calibrate_quantizers(self):
         ]
 
         for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled:
+            if quantizer is not None and quantizer.is_enabled():
                 if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                     quantizer.reset_amax()
                 max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
 
     def forward(self, query, key, value, *args, **kwargs):
         """Apply post-RoPE quantization to KV cache.
-
+
         TEDotProductAttention receives Q, K, V after RoPE is applied,
         so we quantize them directly for KV cache quantization.
         """
@@ -525,7 +541,7 @@ def forward(self, query, key, value, *args, **kwargs):
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Create a sharded state dictionary for distributed checkpointing."""
         sharded_state_dict = {}
-
+
         # First add non-quantizer parameters
         for k, v in self.state_dict(prefix="", keep_vars=True).items():
             if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -542,10 +558,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
                 sharded_state_dict[amax_key] = quantizer._amax
 
         # Process other quantizer parameters in bmm_quantizers
-        quantizer_state_dict = {}
-        for k, v in self.state_dict(prefix="", keep_vars=True).items():
-            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
-                quantizer_state_dict[k] = v
+        quantizer_state_dict = {
+            k: v
+            for k, v in self.state_dict(prefix="", keep_vars=True).items()
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
+        }
 
         if quantizer_state_dict:
             sharded_state_dict.update(
@@ -581,10 +598,11 @@ def modelopt_post_restore(self, name=""):
         super().modelopt_post_restore(name)
 
         def _check_unsupported_states(quantizer):
+            """Check for unsupported quantizer states and warn if found."""
             if not hasattr(quantizer, "state_dict"):
                 return
 
-            for k in quantizer.state_dict().keys():
+            for k in quantizer.state_dict():
                 if k not in ["_amax", "_pre_quant_scale"]:
                     warnings.warn(
                         f"Restore of {k} for {name} is not supported. The restore of this layer might be "
@@ -598,7 +616,7 @@ def _check_unsupported_states(quantizer):
             ("k_bmm_quantizer", self.k_bmm_quantizer),
             ("v_bmm_quantizer", self.v_bmm_quantizer),
         ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
 
             _check_unsupported_states(quantizer)
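To make the mechanism described in the docstrings above concrete, here is a minimal self-contained sketch of post-RoPE Q/K/V quantization with three TensorQuantizer instances. It is an illustrative wrapper module, not the registered _QuantTEDotProductAttention implementation; the class name and structure below are assumptions for illustration only.

import torch
from modelopt.torch.quantization.nn import TensorQuantizer

class PostRoPEKVCacheQuant(torch.nn.Module):
    """Illustrative wrapper: fake-quantize Q/K/V after RoPE, then run the real attention."""

    def __init__(self, attention: torch.nn.Module):
        super().__init__()
        self.attention = attention
        # Mirrors the q/k/v_bmm_quantizer naming used by the plugin above.
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()

    def forward(self, query, key, value, *args, **kwargs):
        # K and V are what end up in the KV cache, so quantizing them here
        # bounds the precision of the cached tensors.
        query = self.q_bmm_quantizer(query)
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)
        return self.attention(query, key, value, *args, **kwargs)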

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 58 additions & 49 deletions
@@ -17,7 +17,6 @@
 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -374,8 +373,10 @@ def test_fp8_real_quantize():
 
 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -386,43 +387,45 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -432,7 +435,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -441,29 +444,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -473,32 +478,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+
 
 @pytest.mark.parametrize(
     "config",
@@ -509,16 +520,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
@@ -530,7 +539,7 @@ def test_kv_cache_quant(config):
 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
     """Test KV cache quantization with sharded state dict save/load.
-
+
     This test verifies the complete workflow of saving and loading KV cache quantized
     models with distributed checkpointing, ensuring quantizer states are properly
     preserved across the save/load cycle.
@@ -539,5 +548,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
     )
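For reference, KV-cache configs like the ones parametrizing these tests follow ModelOpt's quant_cfg wildcard pattern (the same structure the helper edits via config["quant_cfg"]["*output_layer*"]). The snippet below is an illustrative sketch only; the exact config names and values used in the parametrization are not shown in this diff:

import modelopt.torch.quantization as mtq  # as used in the tests above

# Illustrative KV-cache config sketch (assumed values, not taken from this commit):
# enable FP8 (E4M3) quantizers only on the K/V bmm quantizers added by the plugin.
EXAMPLE_KV_CACHE_CFG = {
    "quant_cfg": {
        "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "default": {"enable": False},
    },
    "algorithm": "max",
}

# Usage mirrors _test_kv_cache_quant_helper:
#   quantized_model = mtq.quantize(model, EXAMPLE_KV_CACHE_CFG, forward_fn)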
