3 changes: 3 additions & 0 deletions modelopt/torch/quantization/plugins/__init__.py
@@ -71,3 +71,6 @@

with import_plugin("trl"):
from .trl import *

with import_plugin("mcore"):
from .mcore import *
168 changes: 167 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention
⚠️ Potential issue | 🟠 Major

Guard TEDotProductAttention import when TE is unavailable

Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.

-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None

And guard the registration/class definition behind if TEDotProductAttention is not None: so the module remains usable without TE.

🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around line 25, the
unconditional import of TEDotProductAttention will raise ModuleNotFoundError
when Transformer Engine (TE) is not installed; wrap the import in a try/except
that sets TEDotProductAttention = None on ImportError, then only define and
register the _QuantTEDotProductAttention class (and any related registration
calls) inside an if TEDotProductAttention is not None: block so the module
remains importable and Megatron quantization works without TE.
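
To make the second part of this suggestion concrete, here is a minimal sketch of the guarded registration. It reuses only names already imported in this file (QuantModuleRegistry, QuantModule, TensorQuantizer); the abbreviated class body is an illustration, not the committed implementation.

try:
    from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:  # Transformer Engine not installed
    TEDotProductAttention = None

if TEDotProductAttention is not None:

    @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
    class _QuantTEDotProductAttention(QuantModule):
        """KV-cache quantization wrapper (full body as in the diff below)."""

        def _setup(self):
            # Quantizers for query/key/value, mirroring the real class.
            self.q_bmm_quantizer = TensorQuantizer()
            self.k_bmm_quantizer = TensorQuantizer()
            self.v_bmm_quantizer = TensorQuantizer()

With this guard in place, importing the plugin without TE simply skips the registration instead of failing at import time.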

from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -33,7 +34,8 @@
)
from modelopt.torch.utils.distributed import ParallelState

from ..nn import QuantModuleRegistry, TensorQuantizer
from ..model_calib import max_calibrate
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -460,3 +462,167 @@ class _RealQuantMegatronRowParallelLinear(

def forward(self, input, *args, **kwargs):
return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
class _QuantTEDotProductAttention(QuantModule):
"""Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.

This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention
module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer,
k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after
RoPE has been applied.
"""

def _setup(self):
"""Initialize quantizers for Q, K, V tensors."""
self.q_bmm_quantizer = TensorQuantizer()
self.k_bmm_quantizer = TensorQuantizer()
self.v_bmm_quantizer = TensorQuantizer()

def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device and dtype from the parent module's parameters
param = next(iter(self.parameters()), None)
device = param.device if param is not None else torch.device("cuda")
dtype = param.dtype if param is not None else torch.float16

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
)
else:
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
)

# Calibrate each quantizer
quantizers = [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]

for _, quantizer in quantizers:
if quantizer is not None and quantizer.is_enabled():
if not hasattr(quantizer, "_amax") or quantizer._amax is None:
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
"""
# Quantize Q, K, V
query = self.q_bmm_quantizer(query)
key = self.k_bmm_quantizer(key)
value = self.v_bmm_quantizer(value)

return super().forward(query, key, value, *args, **kwargs)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
sharded_state_dict[prefix + k] = v

# Process _amax in bmm_quantizers
for name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if hasattr(quantizer, "_amax") and quantizer._amax is not None:
amax_key = f"{prefix}{name}._amax"
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

if quantizer_state_dict:
sharded_state_dict.update(
**make_sharded_tensors_for_checkpoint(
quantizer_state_dict, prefix, {}, sharded_offsets
)
)

return sharded_state_dict

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Handle loading state dict for quantizers."""
for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
full_prefix = f"{prefix}{quantizer_name}."
amax_key = f"{prefix}{quantizer_name}._amax"

# If amax is in state_dict, rename it to the format expected by TensorQuantizer
if amax_key in state_dict:
expected_amax_key = f"{full_prefix}_amax"
state_dict[expected_amax_key] = state_dict.pop(amax_key)

# Handle other quantizer states
for k in list(state_dict.keys()):
if "_quantizer" in k and "_amax" not in k:
name = k.split(prefix)[-1] if prefix else k
if name in self.state_dict():
state_dict[k] = state_dict[k].view_as(self.state_dict()[name])

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

def modelopt_post_restore(self, name=""):
"""Restore quantizer states after model loading."""
super().modelopt_post_restore(name)

def _check_unsupported_states(quantizer):
"""Check for unsupported quantizer states and warn if found."""
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "
f"incorrect. Please implement a custom restore for {k}."
)

calibration_needed = False

for quantizer_name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
continue

_check_unsupported_states(quantizer)

if not hasattr(quantizer, "_amax") or quantizer._amax is None:
calibration_needed = True

if calibration_needed:
self._calibrate_quantizers()
183 changes: 183 additions & 0 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -232,6 +232,8 @@ def forward_fn(model):
mtq.W4A8_AWQ_BETA_CFG,
mtq.NVFP4_DEFAULT_CFG,
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
# Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant
# They require TEDotProductAttention which needs transformer_impl="modelopt", not "local"
],
)
@pytest.mark.parametrize("compress", [False, True])
@@ -367,3 +369,184 @@ def forward_fn(model):
def test_fp8_real_quantize():
size = torch.cuda.device_count()
spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl")


def _test_kv_cache_quant_helper(config, rank, size):
"""Helper function for testing KV cache quantization with TEDotProductAttention."""
initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)

# Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
# Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
model = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=1,
hidden_size=64,
num_attention_heads=4,
vocab_size=32,
transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec
).cuda()

# Create dummy input for calibration
prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

def forward_fn(model):
return megatron_prefill(model, prompt_tokens)

# Test KV cache quantization with the given config
quantized_model = mtq.quantize(model, config, forward_fn)

# Find TEDotProductAttention modules and verify they have KV cache quantizers
te_attention_found = False
for name, module in quantized_model.named_modules():
# Check if this is a quantized TEDotProductAttention
if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
te_attention_found = True
# Verify all expected quantizers exist
assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

# Verify K and V quantizers are enabled (main purpose of KV cache configs)
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"

# Quick smoke test that forward still works
output = forward_fn(quantized_model)
assert output is not None, "Forward pass failed"

Comment on lines +374 to +418
⚠️ Potential issue | 🟠 Major

Fix assertions: call TensorQuantizer.is_enabled as a method.

is_enabled is a method, not a property. As written, the assertions will always pass because a bound method is truthy.

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 374 to
418, the assertions check TensorQuantizer.is_enabled as an attribute
(module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled) but
is_enabled is a method; change those assertions to call the method
(module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled()) so
they evaluate correctly.
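
As a side note on why the original assertions silently pass, here is a small self-contained illustration of the bound-method truthiness pitfall; the Quantizer class below is a hypothetical stand-in, not modelopt's TensorQuantizer.

class Quantizer:
    """Toy stand-in with the same is_enabled() method shape."""

    def __init__(self, enabled: bool):
        self._enabled = enabled

    def is_enabled(self) -> bool:
        return self._enabled


q = Quantizer(enabled=False)

# A bound method object is always truthy, so this assertion can never fail,
# even though the quantizer is disabled.
assert q.is_enabled

# Calling the method returns the actual state.
assert q.is_enabled() is False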


def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
"""Helper for testing KV cache quantization with sharded state dict save/load."""
# Disable output_layer quantization (same as other sharded state dict tests)
config["quant_cfg"]["*output_layer*"] = {"enable": False}

Comment on lines +423 to +424
⚠️ Potential issue | 🟠 Major

Do not mutate shared config constants in-place

Line 423 mutates the shared config object (mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG) by updating quant_cfg. Because those constants are reused across parametrized tests (and potentially by library consumers), this causes cross-test contamination and unexpected behavior later in the suite. Make a copy before modifying.

 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
+    # Deep-copy the shared config constant before modifying it (requires `import copy`).
+    config = copy.deepcopy(config)
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 423-424,
the test mutates the shared config constant by assigning
config["quant_cfg"]["*output_layer*"] = {"enable": False}; instead, avoid
in-place mutation by making a copy (use copy.deepcopy(config) or dict deepcopy)
into a local variable and mutate that copy (or create a shallow copy of
quant_cfg and update it) so the original mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG
remain unchanged; replace subsequent uses in the test with the copied config.
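
For reference, a minimal sketch of the copy-before-mutate pattern described above; SHARED_CFG and prepare_test_config are hypothetical names used only for illustration, standing in for mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG and the test helper.

import copy

# Hypothetical module-level constant shared across parametrized tests.
SHARED_CFG = {"quant_cfg": {"*weight_quantizer*": {"num_bits": 8}}}


def prepare_test_config(config):
    # Deep-copy first so edits stay local to this test invocation.
    config = copy.deepcopy(config)
    config["quant_cfg"]["*output_layer*"] = {"enable": False}
    return config


local_cfg = prepare_test_config(SHARED_CFG)
assert "*output_layer*" in local_cfg["quant_cfg"]
assert "*output_layer*" not in SHARED_CFG["quant_cfg"]  # shared constant untouched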

initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)

# Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
model_ref = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2, # At least 2 layers to test multiple attention modules
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention
).cuda()

model_test = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2,
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt",
).cuda()

prompt_tokens = torch.randint(
0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
).cuda()

def forward_fn(model):
return megatron_prefill(model, prompt_tokens)

# Quantize the reference model
model_ref = mtq.quantize(model_ref, config, forward_fn)

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
model_test = mtq.quantize(model_test, config, forward_fn)

# Verify KV cache quantizers were created
kv_quantizers_found = False
for name, module in model_ref.named_modules():
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
kv_quantizers_found = True
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

# Test sharded state dict save/load
sharded_state_dict_test_helper(
tmp_path,
model_ref,
model_test,
forward_fn,
meta_device=False,
version=None,
)

# Verify KV cache quantizers are restored correctly in model_test
for (name_ref, module_ref), (name_test, module_test) in zip(
model_ref.named_modules(), model_test.named_modules()
):
if hasattr(module_ref, "k_bmm_quantizer"):
assert hasattr(module_test, "k_bmm_quantizer"), (
f"K quantizer missing after restore in {name_test}"
)
assert hasattr(module_test, "v_bmm_quantizer"), (
f"V quantizer missing after restore in {name_test}"
)

# Check that quantizer states match
if hasattr(module_ref.k_bmm_quantizer, "_amax"):
assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
f"K quantizer _amax missing in {name_test}"
)
if module_ref.k_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
), f"K quantizer _amax mismatch in {name_test}"

if hasattr(module_ref.v_bmm_quantizer, "_amax"):
assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
f"V quantizer _amax missing in {name_test}"
)
if module_ref.v_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
), f"V quantizer _amax mismatch in {name_test}"


Comment on lines +420 to +513
⚠️ Potential issue | 🟠 Major

Fix is_enabled checks in sharded state dict helper.

Same issue: call is_enabled().

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 420-513,
the assertions that verify KV cache quantizers use the attribute access
.is_enabled instead of calling the method; update those assertions to call
module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled() so
the checks invoke the method rather than reference a non-callable attribute.

@pytest.mark.parametrize(
"config",
[
mtq.FP8_KV_CFG,
mtq.NVFP4_KV_CFG,
],
)
def test_kv_cache_quant(config):
"""Verify KV cache quantization works correctly with TEDotProductAttention.

This test ensures TEDotProductAttention is properly registered and gets the
expected q/k/v_bmm_quantizers when using KV cache configs.

Note: This test requires Transformer Engine to be installed since TEDotProductAttention
is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
"""
spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")


@pytest.mark.parametrize(
"config",
[
mtq.FP8_KV_CFG,
mtq.NVFP4_KV_CFG,
],
)
def test_kv_cache_sharded_state_dict(tmp_path, config):
"""Test KV cache quantization with sharded state dict save/load.

This test verifies the complete workflow of saving and loading KV cache quantized
models with distributed checkpointing, ensuring quantizer states are properly
preserved across the save/load cycle.
"""
size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1
spawn_multiprocess_job(
size=size,
job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
backend="nccl",
)