From 4efc618be6fdb77644e76ec5a9b671438cc5fbba Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 25 Sep 2025 12:08:27 -0700 Subject: [PATCH 1/3] add kv cache quantization for mcore using bmm_quantizers Signed-off-by: Kai Xu --- .../torch/quantization/plugins/__init__.py | 3 + .../torch/quantization/plugins/megatron.py | 150 ++++++++++++++- .../quantization/plugins/test_megatron.py | 174 ++++++++++++++++++ 3 files changed, 326 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index 531c2e476..d1451e37d 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -71,3 +71,6 @@ with import_plugin("trl"): from .trl import * + +with import_plugin("mcore"): + from .mcore import * \ No newline at end of file diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 1cf9416ec..c5d4146e6 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -26,6 +26,7 @@ from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none +from megatron.core.extensions.transformer_engine import TEDotProductAttention from modelopt.torch.opt.plugins.megatron import ( _MegatronMLP, @@ -33,10 +34,11 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, TensorQuantizer +from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +from ..model_calib import max_calibrate __all__ = [] @@ -462,3 +464,149 @@ class _RealQuantMegatronRowParallelLinear( def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) + + +@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"}) +class _QuantTEDotProductAttention(QuantModule): + """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.""" + + def _setup(self): + """Initialize quantizers for Q, K, V tensors.""" + self.q_bmm_quantizer = TensorQuantizer() + self.k_bmm_quantizer = TensorQuantizer() + self.v_bmm_quantizer = TensorQuantizer() + + def _calibrate_quantizers(self): + """Calibrate quantizers with minimal dummy tensors.""" + # Get device from parent module parameters + device = next(self.parameters()).device if self.parameters() else torch.device('cuda') + + # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion + batch_size = 1 + seq_len = 1 + + # Get dimensions from config + num_heads = self.config.num_attention_heads + head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads + + # Determine tensor format (default to sbhd if not specified) + apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False) + qkv_format = "bshd" if apply_rope_fusion else "sbhd" + + if qkv_format == "sbhd": + dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16) + else: + dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16) + + # Calibrate each quantizer + quantizers = [ + ("q_bmm_quantizer", 
self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ] + + for _, quantizer in quantizers: + if quantizer is not None and quantizer.is_enabled: + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + quantizer.reset_amax() + max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) + + def forward(self, query, key, value, *args, **kwargs): + """Apply post-RoPE quantization to KV cache. + + TEDotProductAttention receives Q, K, V after RoPE is applied, + so we quantize them directly for KV cache quantization. + """ + # Quantize Q, K, V + query = self.q_bmm_quantizer(query) + key = self.k_bmm_quantizer(key) + value = self.v_bmm_quantizer(value) + + return super().forward(query, key, value, *args, **kwargs) + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Create a sharded state dictionary for distributed checkpointing.""" + sharded_state_dict = {} + + # First add non-quantizer parameters + for k, v in self.state_dict(prefix="", keep_vars=True).items(): + if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: + sharded_state_dict[prefix + k] = v + + # Process _amax in bmm_quantizers + for name, quantizer in [ + ("q_bmm_quantizer", self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ]: + if hasattr(quantizer, "_amax") and quantizer._amax is not None: + amax_key = f"{prefix}{name}._amax" + sharded_state_dict[amax_key] = quantizer._amax + + # Process other quantizer parameters in bmm_quantizers + quantizer_state_dict = {} + for k, v in self.state_dict(prefix="", keep_vars=True).items(): + if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k: + quantizer_state_dict[k] = v + + if quantizer_state_dict: + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint( + quantizer_state_dict, prefix, {}, sharded_offsets + ) + ) + + return sharded_state_dict + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + """Handle loading state dict for quantizers.""" + for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: + full_prefix = f"{prefix}{quantizer_name}." + amax_key = f"{prefix}{quantizer_name}._amax" + + # If amax is in state_dict, rename it to the format expected by TensorQuantizer + if amax_key in state_dict: + expected_amax_key = f"{full_prefix}_amax" + state_dict[expected_amax_key] = state_dict.pop(amax_key) + + # Handle other quantizer states + for k in list(state_dict.keys()): + if "_quantizer" in k and "_amax" not in k: + name = k.split(prefix)[-1] if prefix else k + if name in self.state_dict(): + state_dict[k] = state_dict[k].view_as(self.state_dict()[name]) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def modelopt_post_restore(self, name=""): + """Restore quantizer states after model loading.""" + super().modelopt_post_restore(name) + + def _check_unsupported_states(quantizer): + if not hasattr(quantizer, "state_dict"): + return + + for k in quantizer.state_dict().keys(): + if k not in ["_amax", "_pre_quant_scale"]: + warnings.warn( + f"Restore of {k} for {name} is not supported. The restore of this layer might be " + f"incorrect. Please implement a custom restore for {k}." 
+ ) + + calibration_needed = False + + for quantizer_name, quantizer in [ + ("q_bmm_quantizer", self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ]: + if not hasattr(self, quantizer_name) or not quantizer.is_enabled: + continue + + _check_unsupported_states(quantizer) + + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + calibration_needed = True + + if calibration_needed: + self._calibrate_quantizers() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 226403ea2..b0f5a3e24 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -17,6 +17,7 @@ import pytest import torch +import torch.nn as nn from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( @@ -230,6 +231,8 @@ def forward_fn(model): mtq.W4A8_AWQ_BETA_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant + # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local" ], ) @pytest.mark.parametrize("compress", [False, True]) @@ -361,3 +364,174 @@ def forward_fn(model): def test_fp8_real_quantize(): size = torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +def _test_kv_cache_quant_helper(config, rank, size): + """Helper function for testing KV cache quantization with TEDotProductAttention.""" + initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) + + # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention + # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention + model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=1, + hidden_size=64, + num_attention_heads=4, + vocab_size=32, + transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec + ).cuda() + + # Create dummy input for calibration + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Test KV cache quantization with the given config + quantized_model = mtq.quantize(model, config, forward_fn) + + # Find TEDotProductAttention modules and verify they have KV cache quantizers + te_attention_found = False + for name, module in quantized_model.named_modules(): + # Check if this is a quantized TEDotProductAttention + if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'): + te_attention_found = True + # Verify all expected quantizers exist + assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}" + + # Verify K and V quantizers are enabled (main purpose of KV cache configs) + assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" + assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" + + assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" + + # Quick smoke test that forward still works + output = forward_fn(quantized_model) + assert output is not None, "Forward pass failed" + + +def 
_test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): + """Helper for testing KV cache quantization with sharded state dict save/load.""" + # Disable output_layer quantization (same as other sharded state dict tests) + config["quant_cfg"]["*output_layer*"] = {"enable": False} + + initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) + + # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") + model_ref = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=2, # At least 2 layers to test multiple attention modules + hidden_size=64, + num_attention_heads=4, + vocab_size=64, + transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention + ).cuda() + + model_test = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=2, + hidden_size=64, + num_attention_heads=4, + vocab_size=64, + transformer_impl="modelopt", + ).cuda() + + prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Quantize the reference model + model_ref = mtq.quantize(model_ref, config, forward_fn) + + # CRITICAL: model_test must also be quantized with the same config + # Otherwise it won't have the KV cache quantizer keys when loading state dict + model_test = mtq.quantize(model_test, config, forward_fn) + + # Verify KV cache quantizers were created + kv_quantizers_found = False + for name, module in model_ref.named_modules(): + if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'): + kv_quantizers_found = True + assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" + assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" + + assert kv_quantizers_found, "No KV cache quantizers found in quantized model" + + # Test sharded state dict save/load + sharded_state_dict_test_helper( + tmp_path, + model_ref, + model_test, + forward_fn, + meta_device=False, + version=None, + ) + + # Verify KV cache quantizers are restored correctly in model_test + for (name_ref, module_ref), (name_test, module_test) in zip( + model_ref.named_modules(), model_test.named_modules() + ): + if hasattr(module_ref, 'k_bmm_quantizer'): + assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}" + assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}" + + # Check that quantizer states match + if hasattr(module_ref.k_bmm_quantizer, '_amax'): + assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}" + if module_ref.k_bmm_quantizer._amax is not None: + assert torch.allclose( + module_ref.k_bmm_quantizer._amax, + module_test.k_bmm_quantizer._amax + ), f"K quantizer _amax mismatch in {name_test}" + + if hasattr(module_ref.v_bmm_quantizer, '_amax'): + assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}" + if module_ref.v_bmm_quantizer._amax is not None: + assert torch.allclose( + module_ref.v_bmm_quantizer._amax, + module_test.v_bmm_quantizer._amax + ), f"V quantizer _amax mismatch in {name_test}" + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, + ], +) +def test_kv_cache_quant(config): + """Verify KV cache quantization works correctly with TEDotProductAttention. 
+ + This test ensures TEDotProductAttention is properly registered and gets the + expected q/k/v_bmm_quantizers when using KV cache configs. + + Note: This test requires Transformer Engine to be installed since TEDotProductAttention + is only available with transformer_impl="modelopt" or "transformer_engine" (not "local"). + """ + spawn_multiprocess_job( + size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl" + ) + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, + ], +) +def test_kv_cache_sharded_state_dict(tmp_path, config): + """Test KV cache quantization with sharded state dict save/load. + + This test verifies the complete workflow of saving and loading KV cache quantized + models with distributed checkpointing, ensuring quantizer states are properly + preserved across the save/load cycle. + """ + size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1 + spawn_multiprocess_job( + size=size, + job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), + backend="nccl" + ) From 1f77518f5e3e001c498d7323cfbdcf3b67b8886c Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 25 Sep 2025 12:11:53 -0700 Subject: [PATCH 2/3] add kv cache quantization for mcore using bmm_quantizers Signed-off-by: Kai Xu --- .../torch/quantization/plugins/__init__.py | 2 +- .../torch/quantization/plugins/megatron.py | 60 ++++++---- .../quantization/plugins/test_megatron.py | 107 ++++++++++-------- 3 files changed, 98 insertions(+), 71 deletions(-) diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index d1451e37d..ac63582bb 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -73,4 +73,4 @@ from .trl import * with import_plugin("mcore"): - from .mcore import * \ No newline at end of file + from .mcore import * diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index c5d4146e6..1f6c27b4b 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,11 +22,11 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none -from megatron.core.extensions.transformer_engine import TEDotProductAttention from modelopt.torch.opt.plugins.megatron import ( _MegatronMLP, @@ -34,11 +34,11 @@ ) from modelopt.torch.utils.distributed import ParallelState +from ..model_calib import max_calibrate from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear -from ..model_calib import max_calibrate __all__ = [] @@ -468,7 +468,13 @@ def forward(self, input, *args, **kwargs): @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"}) class _QuantTEDotProductAttention(QuantModule): - """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.""" + """Quantized version of 
TEDotProductAttention for Megatron models with KV cache quantization. + + This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention + module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer, + k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after + RoPE has been applied. + """ def _setup(self): """Initialize quantizers for Q, K, V tensors.""" @@ -478,25 +484,35 @@ def _setup(self): def _calibrate_quantizers(self): """Calibrate quantizers with minimal dummy tensors.""" - # Get device from parent module parameters - device = next(self.parameters()).device if self.parameters() else torch.device('cuda') - + # Get device and dtype from the parent module's parameters + param = next(iter(self.parameters()), None) + device = param.device if param is not None else torch.device("cuda") + dtype = param.dtype if param is not None else torch.float16 + # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion batch_size = 1 seq_len = 1 - + # Get dimensions from config num_heads = self.config.num_attention_heads - head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads - + head_dim = ( + self.config.kv_channels + if hasattr(self.config, "kv_channels") + else self.config.hidden_size // num_heads + ) + # Determine tensor format (default to sbhd if not specified) - apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False) + apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) qkv_format = "bshd" if apply_rope_fusion else "sbhd" if qkv_format == "sbhd": - dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16) + dummy_tensor = torch.randn( + seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype + ) else: - dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16) + dummy_tensor = torch.randn( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) # Calibrate each quantizer quantizers = [ @@ -506,14 +522,14 @@ def _calibrate_quantizers(self): ] for _, quantizer in quantizers: - if quantizer is not None and quantizer.is_enabled: + if quantizer is not None and quantizer.is_enabled(): if not hasattr(quantizer, "_amax") or quantizer._amax is None: quantizer.reset_amax() max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) def forward(self, query, key, value, *args, **kwargs): """Apply post-RoPE quantization to KV cache. - + TEDotProductAttention receives Q, K, V after RoPE is applied, so we quantize them directly for KV cache quantization. 
""" @@ -527,7 +543,7 @@ def forward(self, query, key, value, *args, **kwargs): def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Create a sharded state dictionary for distributed checkpointing.""" sharded_state_dict = {} - + # First add non-quantizer parameters for k, v in self.state_dict(prefix="", keep_vars=True).items(): if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: @@ -544,10 +560,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): sharded_state_dict[amax_key] = quantizer._amax # Process other quantizer parameters in bmm_quantizers - quantizer_state_dict = {} - for k, v in self.state_dict(prefix="", keep_vars=True).items(): - if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k: - quantizer_state_dict[k] = v + quantizer_state_dict = { + k: v + for k, v in self.state_dict(prefix="", keep_vars=True).items() + if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k + } if quantizer_state_dict: sharded_state_dict.update( @@ -583,10 +600,11 @@ def modelopt_post_restore(self, name=""): super().modelopt_post_restore(name) def _check_unsupported_states(quantizer): + """Check for unsupported quantizer states and warn if found.""" if not hasattr(quantizer, "state_dict"): return - for k in quantizer.state_dict().keys(): + for k in quantizer.state_dict(): if k not in ["_amax", "_pre_quant_scale"]: warnings.warn( f"Restore of {k} for {name} is not supported. The restore of this layer might be " @@ -600,7 +618,7 @@ def _check_unsupported_states(quantizer): ("k_bmm_quantizer", self.k_bmm_quantizer), ("v_bmm_quantizer", self.v_bmm_quantizer), ]: - if not hasattr(self, quantizer_name) or not quantizer.is_enabled: + if not hasattr(self, quantizer_name) or not quantizer.is_enabled(): continue _check_unsupported_states(quantizer) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b0f5a3e24..3a0501c8d 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -17,7 +17,6 @@ import pytest import torch -import torch.nn as nn from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( @@ -368,8 +367,10 @@ def test_fp8_real_quantize(): def _test_kv_cache_quant_helper(config, rank, size): """Helper function for testing KV cache quantization with TEDotProductAttention.""" - initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) - + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention model = get_mcore_gpt_model( @@ -380,43 +381,45 @@ def _test_kv_cache_quant_helper(config, rank, size): vocab_size=32, transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec ).cuda() - + # Create dummy input for calibration prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - + def forward_fn(model): return megatron_prefill(model, prompt_tokens) - + # Test KV cache quantization with the given config quantized_model = mtq.quantize(model, config, forward_fn) - + # Find 
TEDotProductAttention modules and verify they have KV cache quantizers te_attention_found = False for name, module in quantized_model.named_modules(): # Check if this is a quantized TEDotProductAttention - if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'): + if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"): te_attention_found = True # Verify all expected quantizers exist - assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}" - + assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}" + # Verify K and V quantizers are enabled (main purpose of KV cache configs) assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - + assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" - + # Quick smoke test that forward still works output = forward_fn(quantized_model) assert output is not None, "Forward pass failed" - + def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): """Helper for testing KV cache quantization with sharded state dict save/load.""" # Disable output_layer quantization (same as other sharded state dict tests) config["quant_cfg"]["*output_layer*"] = {"enable": False} - - initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) - + + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") model_ref = get_mcore_gpt_model( tensor_model_parallel_size=size, @@ -426,7 +429,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): vocab_size=64, transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention ).cuda() - + model_test = get_mcore_gpt_model( tensor_model_parallel_size=size, num_layers=2, @@ -435,29 +438,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): vocab_size=64, transformer_impl="modelopt", ).cuda() - - prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda() - + + prompt_tokens = torch.randint( + 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) + ).cuda() + def forward_fn(model): return megatron_prefill(model, prompt_tokens) - + # Quantize the reference model model_ref = mtq.quantize(model_ref, config, forward_fn) - + # CRITICAL: model_test must also be quantized with the same config # Otherwise it won't have the KV cache quantizer keys when loading state dict model_test = mtq.quantize(model_test, config, forward_fn) - + # Verify KV cache quantizers were created kv_quantizers_found = False for name, module in model_ref.named_modules(): - if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'): + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): kv_quantizers_found = True assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - + assert kv_quantizers_found, "No KV cache quantizers found in quantized model" - + # Test sharded state dict save/load sharded_state_dict_test_helper( tmp_path, @@ -467,32 +472,38 @@ def forward_fn(model): meta_device=False, version=None, ) - + # Verify KV cache quantizers are restored correctly in model_test for (name_ref, module_ref), (name_test, 
module_test) in zip( model_ref.named_modules(), model_test.named_modules() ): - if hasattr(module_ref, 'k_bmm_quantizer'): - assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}" - assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}" - + if hasattr(module_ref, "k_bmm_quantizer"): + assert hasattr(module_test, "k_bmm_quantizer"), ( + f"K quantizer missing after restore in {name_test}" + ) + assert hasattr(module_test, "v_bmm_quantizer"), ( + f"V quantizer missing after restore in {name_test}" + ) + # Check that quantizer states match - if hasattr(module_ref.k_bmm_quantizer, '_amax'): - assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}" + if hasattr(module_ref.k_bmm_quantizer, "_amax"): + assert hasattr(module_test.k_bmm_quantizer, "_amax"), ( + f"K quantizer _amax missing in {name_test}" + ) if module_ref.k_bmm_quantizer._amax is not None: assert torch.allclose( - module_ref.k_bmm_quantizer._amax, - module_test.k_bmm_quantizer._amax + module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax ), f"K quantizer _amax mismatch in {name_test}" - - if hasattr(module_ref.v_bmm_quantizer, '_amax'): - assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}" + + if hasattr(module_ref.v_bmm_quantizer, "_amax"): + assert hasattr(module_test.v_bmm_quantizer, "_amax"), ( + f"V quantizer _amax missing in {name_test}" + ) if module_ref.v_bmm_quantizer._amax is not None: assert torch.allclose( - module_ref.v_bmm_quantizer._amax, - module_test.v_bmm_quantizer._amax + module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax ), f"V quantizer _amax mismatch in {name_test}" - + @pytest.mark.parametrize( "config", @@ -503,16 +514,14 @@ def forward_fn(model): ) def test_kv_cache_quant(config): """Verify KV cache quantization works correctly with TEDotProductAttention. - - This test ensures TEDotProductAttention is properly registered and gets the + + This test ensures TEDotProductAttention is properly registered and gets the expected q/k/v_bmm_quantizers when using KV cache configs. - + Note: This test requires Transformer Engine to be installed since TEDotProductAttention is only available with transformer_impl="modelopt" or "transformer_engine" (not "local"). """ - spawn_multiprocess_job( - size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl" - ) + spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl") @pytest.mark.parametrize( @@ -524,7 +533,7 @@ def test_kv_cache_quant(config): ) def test_kv_cache_sharded_state_dict(tmp_path, config): """Test KV cache quantization with sharded state dict save/load. - + This test verifies the complete workflow of saving and loading KV cache quantized models with distributed checkpointing, ensuring quantizer states are properly preserved across the save/load cycle. 
@@ -533,5 +542,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config): spawn_multiprocess_job( size=size, job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), - backend="nccl" + backend="nccl", ) From d2c05f2d411227b4371508b0bc4cd5a8a11f1b91 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 25 Sep 2025 15:34:41 -0700 Subject: [PATCH 3/3] update CHANGELOG Signed-off-by: Kai Xu --- CHANGELOG.rst | 8 ++++++++ modelopt/torch/quantization/plugins/__init__.py | 3 --- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c96f9048e..4de7f2844 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Model Optimizer Changelog (Linux) ================================= +0.41 (2025-12-xx) +^^^^^^^^^^^^^^^^^ + +**Deprecations** + +**New Features** +- Add FP8/NVFP4 KV cache quantization support for Megatron Core models. + 0.39 (2025-11-xx) ^^^^^^^^^^^^^^^^^ diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index ac63582bb..531c2e476 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -71,6 +71,3 @@ with import_plugin("trl"): from .trl import * - -with import_plugin("mcore"): - from .mcore import *
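
Usage sketch (not part of the patches): the example below shows how the KV cache quantization path added by this series is exercised end to end, mirroring the body of `_test_kv_cache_quant_helper` in tests/gpu/torch/quantization/plugins/test_megatron.py. It is a minimal sketch under the same assumptions as the tests: it runs inside the `spawn_multiprocess_job(..., backend="nccl")` harness on a CUDA machine with Megatron-Core and Transformer Engine installed, and `get_mcore_gpt_model`, `initialize_for_megatron`, and `megatron_prefill` are the test utilities imported there, not public ModelOpt APIs.

    import torch

    import modelopt.torch.quantization as mtq
    from _test_utils.torch_dist.plugins.megatron_common import (
        get_mcore_gpt_model,
        initialize_for_megatron,
        megatron_prefill,
    )


    def kv_cache_quantization_sketch(rank, size):
        # Same setup as _test_kv_cache_quant_helper; intended to be launched via
        # spawn_multiprocess_job(size=size, job=kv_cache_quantization_sketch, backend="nccl").
        initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)

        # transformer_impl="modelopt" (or "transformer_engine") is required so the attention
        # layers are TEDotProductAttention; the "local" implementation never reaches the new
        # _QuantTEDotProductAttention plugin.
        model = get_mcore_gpt_model(
            tensor_model_parallel_size=size,
            num_layers=1,
            hidden_size=64,
            num_attention_heads=4,
            vocab_size=32,
            transformer_impl="modelopt",
        ).cuda()

        # Dummy prompts drive the calibration forward pass, exactly as in the tests.
        prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

        def forward_fn(m):
            return megatron_prefill(m, prompt_tokens)

        # FP8 KV cache quantization; mtq.NVFP4_KV_CFG is used the same way.
        model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)

        # After quantization, each TEDotProductAttention carries q/k/v_bmm_quantizers whose
        # amax values were calibrated post-RoPE.
        for name, module in model.named_modules():
            if hasattr(module, "k_bmm_quantizer"):
                print(
                    name,
                    getattr(module.k_bmm_quantizer, "_amax", None),
                    getattr(module.v_bmm_quantizer, "_amax", None),
                )
        return model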