2 changes: 1 addition & 1 deletion modelopt/torch/quantization/plugins/__init__.py
@@ -73,4 +73,4 @@
from .trl import *

with import_plugin("mcore"):
from .mcore import *
from .mcore import *
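
The `with import_plugin("mcore"):` guard keeps the mcore plugin optional: if Megatron-Core is not importable, the plugin is skipped instead of failing the whole package import. A minimal sketch of such a guarded-import helper (a hypothetical stand-in, not modelopt's actual `import_plugin`):

from contextlib import contextmanager

@contextmanager
def import_plugin(plugin_name):
    # Swallow ImportError so an optional plugin is skipped when its
    # backing framework (here: megatron.core) is not installed.
    try:
        yield
    except ImportError:
        print(f"Plugin '{plugin_name}' skipped: dependency not installed.")
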
116 changes: 92 additions & 24 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -22,23 +22,23 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_tensor_model_parallel_group_if_none
from megatron.core.extensions.transformer_engine import TEDotProductAttention

from modelopt.torch.opt.plugins.megatron import (
_MegatronMLP,
register_modelopt_extra_state_callbacks,
)
from modelopt.torch.utils.distributed import ParallelState

from ..model_calib import max_calibrate
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
from ..model_calib import max_calibrate

__all__ = []

@@ -461,6 +461,14 @@ class _RealQuantMegatronRowParallelLinear(
_scale_tensor_shard_axis = 1

def forward(self, input, *args, **kwargs):
"""
Compute the forward pass using the row-parallel linear implementation.

Forwards all positional and keyword arguments to the row-parallel parent implementation.

Returns:
torch.Tensor: The output activations produced by the linear layer.
"""
return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
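
For context, `_scale_tensor_shard_axis = 1` above reflects that a row-parallel linear splits its weight along the input dimension (dim 1), so a per-block weight scale has to be sharded along the same axis. A torch-only illustration with made-up shapes (not modelopt's actual sharding code):

import torch

tp_size = 4                                            # example tensor-parallel size
out_features, in_features = 16, 64
weight = torch.randn(out_features, in_features)
scale = torch.randn(out_features, in_features // 16)   # assumed: one scale per 16-wide block

# Row-parallel: shard the input dimension; the scale follows the same axis.
weight_shard = torch.chunk(weight, tp_size, dim=1)[0]  # [16, 16]
scale_shard = torch.chunk(scale, tp_size, dim=1)[0]    # [16, 1]
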


@@ -469,32 +477,46 @@ class _QuantTEDotProductAttention(QuantModule):
"""Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""

def _setup(self):
"""Initialize quantizers for Q, K, V tensors."""
"""
Create and attach three TensorQuantizer instances as q_bmm_quantizer, k_bmm_quantizer, and v_bmm_quantizer for quantizing query, key, and value tensors.
"""
self.q_bmm_quantizer = TensorQuantizer()
self.k_bmm_quantizer = TensorQuantizer()
self.v_bmm_quantizer = TensorQuantizer()

def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device from parent module parameters
device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
"""
Calibrate the module's Q/K/V tensor quantizers using minimal dummy inputs.

Creates a tiny float16 dummy tensor shaped for the attention QKV layout ("sbhd" or "bshd", chosen from self.config.apply_rope_fusion) and runs it through q_bmm_quantizer, k_bmm_quantizer, and v_bmm_quantizer to compute and store `_amax`. Only quantizers that are enabled and do not yet have an `_amax` are calibrated.
"""
# Get device from parent module parameters
params = list(self.parameters())
device = params[0].device if params else torch.device("cuda")

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads

head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
)
else:
dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
)

# Calibrate each quantizer
quantizers = [
@@ -510,10 +532,16 @@ def _calibrate_quantizers(self):
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
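
Conceptually, max calibration just records the largest absolute value the quantizer sees and stores it as `_amax`; the dummy tensor above exists only so that quantizers restored without a range still get one. A torch-only sketch of the idea (not the real TensorQuantizer/max_calibrate internals):

import torch

class ToyMaxCalibrator:
    """Tracks amax = max(|x|) over calibration inputs; passes data through unchanged."""

    def __init__(self):
        self._amax = None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        seen = x.abs().amax()
        self._amax = seen if self._amax is None else torch.maximum(self._amax, seen)
        return x

q = ToyMaxCalibrator()
q(torch.randn(1, 1, 8, 64, dtype=torch.float16))  # sbhd-shaped dummy, as above
print(q._amax)
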

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.
"""
Quantize the provided query, key, and value tensors for the KV cache and forward them to the base attention implementation.

Parameters:
query (Tensor): Query tensor (already rotated by RoPE) to be quantized and used for attention.
key (Tensor): Key tensor (already rotated by RoPE) to be quantized and used for attention.
value (Tensor): Value tensor to be quantized and used for attention.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
Returns:
The output of the parent attention `forward` called with the quantized query, key, and value.
"""
# Quantize Q, K, V
query = self.q_bmm_quantizer(query)
@@ -523,9 +551,22 @@ def forward(self, query, key, value, *args, **kwargs):
return super().forward(query, key, value, *args, **kwargs)
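
Once each quantizer has an `_amax`, the calls above fake-quantize Q, K, and V before the attention math, which is what makes the cached key/value tensors effectively low-precision. A torch-only sketch of per-tensor fake quantization (illustrative int8; the real format depends on the active quantization config):

import torch

def fake_quant_per_tensor(x: torch.Tensor, amax: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # Map values in [-amax, amax] onto the signed integer grid, then dequantize.
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

key = torch.randn(2, 1, 8, 64)
key_q = fake_quant_per_tensor(key, key.abs().amax())
print((key - key_q).abs().max())  # small quantization error
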

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}
"""
Builds a sharded state dictionary containing non-quantizer parameters and bmm-quantizer state for distributed checkpointing.

Parameters:
prefix (str): Key prefix to prepend to returned state keys.
sharded_offsets (tuple): Offsets describing shard positions for sharded tensors (passed to make_sharded_tensors_for_checkpoint).
metadata: Ignored by this implementation (kept for API compatibility).

Returns:
state_dict (dict): Mapping from checkpoint keys to tensors, including:
- Non-quantizer module tensors (prefixed).
- Per-quantizer `_amax` entries for q/k/v bmm quantizers when present.
- Other quantizer tensors processed into sharded tensors via the checkpoint helper.
"""
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -542,10 +583,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {}
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
quantizer_state_dict[k] = v
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

if quantizer_state_dict:
sharded_state_dict.update(
@@ -557,7 +599,18 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
return sharded_state_dict
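
The split performed above, plain module tensors first and quantizer state afterwards, can be seen on an ordinary state dict. A small illustration in plain Python (key names are made up; the real method also wraps quantizer tensors via make_sharded_tensors_for_checkpoint):

import torch

state = {
    "linear_proj.weight": torch.randn(8, 8),
    "q_bmm_quantizer._amax": torch.tensor(3.0),
    "k_bmm_quantizer._some_state": torch.tensor(0.1),  # hypothetical non-amax quantizer tensor
}
prefix = "decoder.layers.0.self_attention.core_attention."

non_quantizer = {prefix + k: v for k, v in state.items() if "_quantizer" not in k}
amax_entries = {prefix + k: v for k, v in state.items() if k.endswith("._amax")}
other_quantizer = {k: v for k, v in state.items() if "_quantizer" in k and "_amax" not in k}
print(sorted(non_quantizer), sorted(amax_entries), sorted(other_quantizer))
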

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Handle loading state dict for quantizers."""
"""
Adjust quantizer entries in a loaded state dict to match this module's expected keys and tensor shapes before delegating to the parent loader.

This method:
- Renames per-quantizer `_amax` keys so that checkpoint entries stored as `{prefix}{quantizer_name}._amax` match the key format expected by the local TensorQuantizer submodules.
- Reshapes any remaining quantizer state tensors (keys containing `_quantizer` but not `_amax`) to match the corresponding tensor shapes in this module's `state_dict`.
- Calls the superclass `_load_from_state_dict` with the adjusted `state_dict`.

Parameters:
state_dict (dict): The incoming state dictionary being loaded; modified in-place to align quantizer keys and shapes.
prefix (str): The prefix applied to keys for this module in `state_dict`.
"""
for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
full_prefix = f"{prefix}{quantizer_name}."
amax_key = f"{prefix}{quantizer_name}._amax"
@@ -577,14 +630,29 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
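
A minimal sketch of the shape fix-up portion of this method: reshape incoming quantizer tensors so they match what the local module's state_dict expects, then hand everything to the parent loader (simplified; the real code also renames `_amax` keys as described in the docstring):

import torch

def reshape_quantizer_entries(state_dict: dict, local_state: dict) -> dict:
    # Align incoming quantizer tensors with the locally expected shapes;
    # every other entry is left untouched.
    for key, incoming in list(state_dict.items()):
        if "_quantizer" in key and key in local_state:
            state_dict[key] = incoming.reshape(local_state[key].shape)
    return state_dict

local = {"q_bmm_quantizer._amax": torch.tensor(1.0)}        # scalar amax expected locally
incoming = {"q_bmm_quantizer._amax": torch.tensor([3.0])}   # 1-element tensor from the checkpoint
print(reshape_quantizer_entries(incoming, local))           # reshaped to a scalar tensor
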

def modelopt_post_restore(self, name=""):
"""Restore quantizer states after model loading."""
"""
Perform post-restore validation for attention quantizers and trigger calibration if needed.

Checks each of the instance's Q/K/V BMM quantizers (if present and enabled) for unsupported saved state keys and emits a warning that identifies the provided `name` when such keys are found. If any enabled quantizer lacks a stored `_amax` value, calibration is run via self._calibrate_quantizers().

Parameters:
name (str): Human-readable identifier for the module being restored; included in warning messages to help locate the layer.
"""
super().modelopt_post_restore(name)

def _check_unsupported_states(quantizer):
"""
Check a quantizer's saved state keys and warn about any unsupported entries.

Inspects quantizer.state_dict() (if present) and emits a warning for each key other than `_amax` and `_pre_quant_scale` indicating that restoring that key is not supported.

Parameters:
quantizer: An object with a `state_dict()` method (typically a TensorQuantizer) whose saved state keys will be validated.
"""
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict().keys():
for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "