diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index 531c2e476..ac63582bb 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -71,3 +71,6 @@ with import_plugin("trl"): from .trl import * + +with import_plugin("mcore"): + from .mcore import * diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index ab64a795a..d309c4436 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,6 +22,7 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -33,7 +34,8 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, TensorQuantizer +from ..model_calib import max_calibrate +from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -460,3 +462,167 @@ class _RealQuantMegatronRowParallelLinear( def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) + + +@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"}) +class _QuantTEDotProductAttention(QuantModule): + """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization. + + This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention + module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer, + k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after + RoPE has been applied. 
+ """ + + def _setup(self): + """Initialize quantizers for Q, K, V tensors.""" + self.q_bmm_quantizer = TensorQuantizer() + self.k_bmm_quantizer = TensorQuantizer() + self.v_bmm_quantizer = TensorQuantizer() + + def _calibrate_quantizers(self): + """Calibrate quantizers with minimal dummy tensors.""" + # Get device and dtype from the parent module's parameters + param = next(iter(self.parameters()), None) + device = param.device if param is not None else torch.device("cuda") + dtype = param.dtype if param is not None else torch.float16 + + # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion + batch_size = 1 + seq_len = 1 + + # Get dimensions from config + num_heads = self.config.num_attention_heads + head_dim = ( + self.config.kv_channels + if hasattr(self.config, "kv_channels") + else self.config.hidden_size // num_heads + ) + + # Determine tensor format (default to sbhd if not specified) + apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) + qkv_format = "bshd" if apply_rope_fusion else "sbhd" + + if qkv_format == "sbhd": + dummy_tensor = torch.randn( + seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype + ) + else: + dummy_tensor = torch.randn( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + # Calibrate each quantizer + quantizers = [ + ("q_bmm_quantizer", self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ] + + for _, quantizer in quantizers: + if quantizer is not None and quantizer.is_enabled(): + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + quantizer.reset_amax() + max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) + + def forward(self, query, key, value, *args, **kwargs): + """Apply post-RoPE quantization to KV cache. + + TEDotProductAttention receives Q, K, V after RoPE is applied, + so we quantize them directly for KV cache quantization. 
+ """ + # Quantize Q, K, V + query = self.q_bmm_quantizer(query) + key = self.k_bmm_quantizer(key) + value = self.v_bmm_quantizer(value) + + return super().forward(query, key, value, *args, **kwargs) + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Create a sharded state dictionary for distributed checkpointing.""" + sharded_state_dict = {} + + # First add non-quantizer parameters + for k, v in self.state_dict(prefix="", keep_vars=True).items(): + if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: + sharded_state_dict[prefix + k] = v + + # Process _amax in bmm_quantizers + for name, quantizer in [ + ("q_bmm_quantizer", self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ]: + if hasattr(quantizer, "_amax") and quantizer._amax is not None: + amax_key = f"{prefix}{name}._amax" + sharded_state_dict[amax_key] = quantizer._amax + + # Process other quantizer parameters in bmm_quantizers + quantizer_state_dict = { + k: v + for k, v in self.state_dict(prefix="", keep_vars=True).items() + if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k + } + + if quantizer_state_dict: + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint( + quantizer_state_dict, prefix, {}, sharded_offsets + ) + ) + + return sharded_state_dict + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + """Handle loading state dict for quantizers.""" + for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: + full_prefix = f"{prefix}{quantizer_name}." + amax_key = f"{prefix}{quantizer_name}._amax" + + # If amax is in state_dict, rename it to the format expected by TensorQuantizer + if amax_key in state_dict: + expected_amax_key = f"{full_prefix}_amax" + state_dict[expected_amax_key] = state_dict.pop(amax_key) + + # Handle other quantizer states + for k in list(state_dict.keys()): + if "_quantizer" in k and "_amax" not in k: + name = k.split(prefix)[-1] if prefix else k + if name in self.state_dict(): + state_dict[k] = state_dict[k].view_as(self.state_dict()[name]) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def modelopt_post_restore(self, name=""): + """Restore quantizer states after model loading.""" + super().modelopt_post_restore(name) + + def _check_unsupported_states(quantizer): + """Check for unsupported quantizer states and warn if found.""" + if not hasattr(quantizer, "state_dict"): + return + + for k in quantizer.state_dict(): + if k not in ["_amax", "_pre_quant_scale"]: + warnings.warn( + f"Restore of {k} for {name} is not supported. The restore of this layer might be " + f"incorrect. Please implement a custom restore for {k}." 
+ ) + + calibration_needed = False + + for quantizer_name, quantizer in [ + ("q_bmm_quantizer", self.q_bmm_quantizer), + ("k_bmm_quantizer", self.k_bmm_quantizer), + ("v_bmm_quantizer", self.v_bmm_quantizer), + ]: + if not hasattr(self, quantizer_name) or not quantizer.is_enabled(): + continue + + _check_unsupported_states(quantizer) + + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + calibration_needed = True + + if calibration_needed: + self._calibrate_quantizers() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c3630e028..ffc778cb9 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -232,6 +232,8 @@ def forward_fn(model): mtq.W4A8_AWQ_BETA_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant + # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local" ], ) @pytest.mark.parametrize("compress", [False, True]) @@ -367,3 +369,184 @@ def forward_fn(model): def test_fp8_real_quantize(): size = torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +def _test_kv_cache_quant_helper(config, rank, size): + """Helper function for testing KV cache quantization with TEDotProductAttention.""" + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + + # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention + # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention + model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=1, + hidden_size=64, + num_attention_heads=4, + vocab_size=32, + transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec + ).cuda() + + # Create dummy input for calibration + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Test KV cache quantization with the given config + quantized_model = mtq.quantize(model, config, forward_fn) + + # Find TEDotProductAttention modules and verify they have KV cache quantizers + te_attention_found = False + for name, module in quantized_model.named_modules(): + # Check if this is a quantized TEDotProductAttention + if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"): + te_attention_found = True + # Verify all expected quantizers exist + assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}" + + # Verify K and V quantizers are enabled (main purpose of KV cache configs) + assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" + assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" + + assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" + + # Quick smoke test that forward still works + output = forward_fn(quantized_model) + assert output is not None, "Forward pass failed" + + +def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): + """Helper for testing KV cache quantization with sharded state dict save/load.""" + # Disable output_layer quantization (same as other sharded state dict tests) + 
config["quant_cfg"]["*output_layer*"] = {"enable": False} + + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + + # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") + model_ref = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=2, # At least 2 layers to test multiple attention modules + hidden_size=64, + num_attention_heads=4, + vocab_size=64, + transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention + ).cuda() + + model_test = get_mcore_gpt_model( + tensor_model_parallel_size=size, + num_layers=2, + hidden_size=64, + num_attention_heads=4, + vocab_size=64, + transformer_impl="modelopt", + ).cuda() + + prompt_tokens = torch.randint( + 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) + ).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Quantize the reference model + model_ref = mtq.quantize(model_ref, config, forward_fn) + + # CRITICAL: model_test must also be quantized with the same config + # Otherwise it won't have the KV cache quantizer keys when loading state dict + model_test = mtq.quantize(model_test, config, forward_fn) + + # Verify KV cache quantizers were created + kv_quantizers_found = False + for name, module in model_ref.named_modules(): + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_quantizers_found = True + assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" + assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" + + assert kv_quantizers_found, "No KV cache quantizers found in quantized model" + + # Test sharded state dict save/load + sharded_state_dict_test_helper( + tmp_path, + model_ref, + model_test, + forward_fn, + meta_device=False, + version=None, + ) + + # Verify KV cache quantizers are restored correctly in model_test + for (name_ref, module_ref), (name_test, module_test) in zip( + model_ref.named_modules(), model_test.named_modules() + ): + if hasattr(module_ref, "k_bmm_quantizer"): + assert hasattr(module_test, "k_bmm_quantizer"), ( + f"K quantizer missing after restore in {name_test}" + ) + assert hasattr(module_test, "v_bmm_quantizer"), ( + f"V quantizer missing after restore in {name_test}" + ) + + # Check that quantizer states match + if hasattr(module_ref.k_bmm_quantizer, "_amax"): + assert hasattr(module_test.k_bmm_quantizer, "_amax"), ( + f"K quantizer _amax missing in {name_test}" + ) + if module_ref.k_bmm_quantizer._amax is not None: + assert torch.allclose( + module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax + ), f"K quantizer _amax mismatch in {name_test}" + + if hasattr(module_ref.v_bmm_quantizer, "_amax"): + assert hasattr(module_test.v_bmm_quantizer, "_amax"), ( + f"V quantizer _amax missing in {name_test}" + ) + if module_ref.v_bmm_quantizer._amax is not None: + assert torch.allclose( + module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax + ), f"V quantizer _amax mismatch in {name_test}" + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, + ], +) +def test_kv_cache_quant(config): + """Verify KV cache quantization works correctly with TEDotProductAttention. + + This test ensures TEDotProductAttention is properly registered and gets the + expected q/k/v_bmm_quantizers when using KV cache configs. 
+ + Note: This test requires Transformer Engine to be installed since TEDotProductAttention + is only available with transformer_impl="modelopt" or "transformer_engine" (not "local"). + """ + spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl") + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, + ], +) +def test_kv_cache_sharded_state_dict(tmp_path, config): + """Test KV cache quantization with sharded state dict save/load. + + This test verifies the complete workflow of saving and loading KV cache quantized + models with distributed checkpointing, ensuring quantizer states are properly + preserved across the save/load cycle. + """ + size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1 + spawn_multiprocess_job( + size=size, + job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), + backend="nccl", + )
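
For reference, a minimal usage sketch of the feature exercised by the tests above. This is illustrative only, not part of the patch: it assumes the same test helpers used in test_megatron.py (initialize_for_megatron, get_mcore_gpt_model, megatron_prefill) are importable, that a distributed process group is already initialized (e.g. via spawn_multiprocess_job as in the tests), and that Megatron-Core plus Transformer Engine are installed on a CUDA machine. The model dimensions and seed are arbitrary small values chosen for the sketch.

    import torch
    import modelopt.torch.quantization as mtq

    # Set up Megatron parallel state for a single GPU (test-suite helper).
    initialize_for_megatron(
        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
    )

    # transformer_impl="modelopt" (or "transformer_engine") is required so the
    # attention layers are TEDotProductAttention; the "local" implementation does
    # not pick up the q/k/v_bmm_quantizers registered by this plugin.
    model = get_mcore_gpt_model(
        tensor_model_parallel_size=1,
        num_layers=1,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=32,
        transformer_impl="modelopt",
    ).cuda()

    prompt_tokens = torch.randint(
        0, model.vocab_size, (2, model.max_sequence_length)
    ).cuda()

    def forward_fn(model):
        # Any representative forward pass serves as the calibration loop.
        return megatron_prefill(model, prompt_tokens)

    # FP8_KV_CFG (or NVFP4_KV_CFG) enables the k/v_bmm_quantizers on each
    # TEDotProductAttention; amax values are collected while forward_fn runs.
    model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)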