
Conversation


@kaix-nv kaix-nv commented Sep 25, 2025

What does this PR do?

Type of change: New feature

Overview: Adds KV-cache quantization support for TEDotProductAttention in Megatron-Core (mcore) GPT models: per-tensor Q/K/V bmm quantizers with calibration, plus sharded checkpoint save/load that preserves quantizer state (including _amax).

Usage

pytest tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant -v
pytest tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict -v
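
For context, a hedged sketch of the library call these tests exercise; it mirrors the mtq.quantize usage in the test helpers quoted later in this thread, and the wrapper name quantize_kv_cache is hypothetical, not part of this PR:

import modelopt.torch.quantization as mtq

def quantize_kv_cache(model, forward_fn, cfg=None):
    # `model` is assumed to be a Megatron-Core GPT model built with transformer_impl="modelopt",
    # and `forward_fn` a calibration forward pass, as in the test helpers below.
    cfg = cfg if cfg is not None else mtq.FP8_KV_CFG  # or mtq.NVFP4_KV_CFG
    return mtq.quantize(model, cfg, forward_fn)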

Testing

tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config1]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config0]

tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config0] PASSED
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config1] PASSED

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

  • New Features

    • KV-cache quantization support for TEDotProductAttention in Megatron GPT models, with per-tensor Q/K/V quantization and calibration.
    • Sharded checkpointing now preserves KV quantizer state (including amax) for reliable restore.
    • Plugin auto-loading extended to include the Megatron quantization integration.
  • Tests

    • Added tests covering KV-cache quantization and sharded save/load restore for TEDotProductAttention under Megatron configurations.

@kaix-nv kaix-nv requested a review from a team as a code owner September 25, 2025 19:29
@kaix-nv kaix-nv requested a review from cjluo-nv September 25, 2025 19:29

copy-pr-bot bot commented Sep 25, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 25, 2025

Walkthrough

Adds mcore plugin import; introduces a QuantModule _QuantTEDotProductAttention for Megatron TEDotProductAttention with KV-cache quantization, checkpoint save/load and post-restore handling; and adds GPU tests verifying KV quantization and sharded-state-dict restore. Duplicate class block present in patch.

Changes

Plugin registration / imports (modelopt/torch/quantization/plugins/__init__.py)
  • Adds an import block that loads the mcore plugin via the import_plugin context manager, matching the existing plugin imports.
Megatron TEDotProductAttention quantization (modelopt/torch/quantization/plugins/megatron.py)
  • Adds _QuantTEDotProductAttention, registered via QuantModuleRegistry for TEDotProductAttention: quantizer setup for Q/K/V, calibration (_calibrate_quantizers / max_calibrate), a forward path that applies quantization, sharded_state_dict export of quantizer state and _amax, _load_from_state_dict remapping/reshaping of quantizer state for local shards, and modelopt_post_restore to validate and trigger calibration. The patch contains a duplicated class block. Also adds public imports (max_calibrate, QuantModule, QuantModuleRegistry, TensorQuantizer).
GPU tests for KV-cache quantization (tests/gpu/torch/quantization/plugins/test_megatron.py)
  • Adds tests and helper processes for KV-cache quantization and sharded-state-dict restore on Megatron TEDotProductAttention (configs: FP8_KV_CFG, NVFP4_KV_CFG). The tests verify the presence and enabled state of k_bmm_quantizer / v_bmm_quantizer, run forward smoke checks, and validate that sharded checkpoint save/load restores quantizer state, including _amax.
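
For orientation, a minimal runnable sketch of the wrapper pattern the megatron.py summary above describes: create per-tensor Q/K/V bmm quantizers in _setup, quantize the attention inputs, then delegate to the parent attention forward. TEDotProductAttention and TensorQuantizer are replaced by tiny stand-ins here, so this illustrates the structure only and is not the PR's actual class.

import torch
import torch.nn as nn


class _FakeDotProductAttention(nn.Module):
    """Stand-in for TEDotProductAttention: plain scaled dot-product attention."""

    def forward(self, query, key, value):
        # inputs laid out as [seq, batch, heads, head_dim] (the 'sbhd' format)
        scores = torch.einsum("sbhd,tbhd->bhst", query, key) / query.shape[-1] ** 0.5
        return torch.einsum("bhst,tbhd->sbhd", scores.softmax(dim=-1), value)


class _FakeTensorQuantizer(nn.Module):
    """Stand-in for TensorQuantizer: records amax and passes the tensor through."""

    def forward(self, x):
        self._amax = x.abs().amax()
        return x


class _QuantAttentionSketch(_FakeDotProductAttention):
    def __init__(self):
        super().__init__()
        self._setup()

    def _setup(self):
        # one per-tensor quantizer per bmm input, as listed in the summary above
        self.q_bmm_quantizer = _FakeTensorQuantizer()
        self.k_bmm_quantizer = _FakeTensorQuantizer()
        self.v_bmm_quantizer = _FakeTensorQuantizer()

    def forward(self, query, key, value):
        # quantize post-RoPE Q/K/V, then hand the results to the parent attention
        return super().forward(
            self.q_bmm_quantizer(query),
            self.k_bmm_quantizer(key),
            self.v_bmm_quantizer(value),
        )


attn = _QuantAttentionSketch()
q = k = v = torch.randn(3, 2, 4, 8)  # [seq, batch, heads, head_dim]
print(attn(q, k, v).shape, float(attn.k_bmm_quantizer._amax))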

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Trainer
  participant GPT as Megatron GPT
  participant QMod as _QuantTEDotProductAttention
  participant QZ as TensorQuantizers
  participant TE as TEDotProductAttention (base)

  Trainer->>GPT: build model (transformer_impl="modelopt")
  GPT->>QMod: instantiate -> _setup()
  QMod->>QZ: create q_bmm / k_bmm / v_bmm

  Trainer->>QMod: optional calibration
  alt calibration enabled
    QMod->>QZ: _calibrate_quantizers() / max_calibrate
  end

  Trainer->>QMod: forward(query,key,value,...)
  QMod->>QZ: quantize Q / K / V (if enabled)
  QMod->>TE: call parent forward with (quantized) Q/K/V
  TE-->>Trainer: attention output
sequenceDiagram
  autonumber
  participant Saver as Checkpoint Save
  participant QMod as _QuantTEDotProductAttention
  participant SD as sharded_state_dict
  participant Loader as Checkpoint Load

  Saver->>QMod: sharded_state_dict(prefix, offsets, metadata)
  QMod->>SD: include non-quant params, quantizer states (+ _amax)
  SD-->>Saver: checkpoint

  Loader->>QMod: _load_from_state_dict(state_dict, prefix, ...)
  QMod->>QMod: remap amax keys, reshape quantizer tensors
  Loader->>QMod: modelopt_post_restore()
  QMod->>QMod: validate/prepare quantizers, trigger calibration if needed
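
For concreteness, a small illustration of the quantizer-state keys the save/load flow above round-trips. The layer prefix is made up for this example; only the trailing "<quantizer>._amax" key format is taken from the review comments further down.

# Hypothetical module prefix; the "._amax" suffix follows the key format quoted in the
# _load_from_state_dict discussion below.
prefix = "decoder.layers.0.self_attention.core_attention."
amax_keys = [
    f"{prefix}{name}._amax"
    for name in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer")
]
print(amax_keys)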

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I twitch my whiskers at quantized beams,
KV hops into cache in tiny dreams.
Q, K, V in tidy rows,
shards restored as buffer grows.
Carrots crunch — tests pass gleams. 🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 73.33%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check (✅ Passed): The title concisely and accurately describes the primary enhancement of the pull request, which is adding support for key-value cache quantization in the mcore backend using batch-matrix-multiplication quantizers. It directly reflects the main change implemented in both the plugin code and associated tests without extraneous details. This phrasing will help reviewers and future readers quickly understand the core purpose of the PR.
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch kaix/kvcache_mcore


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

588-620: Fix is_enabled checks in post‑restore to avoid unconditional calibration.

Call the method; otherwise it always evaluates truthy and may trigger unnecessary calibration.

Apply:

-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/__init__.py (1)

75-76: Gate the new mcore plugin on an importable package (likely megatron.core) and verify plugin file exists.

import_plugin("mcore") will only import .mcore if a top-level module named mcore exists. That’s probably not what you want. If the plugin depends on Megatron-Core, gate on "megatron.core" (or whichever module actually guarantees availability) and ensure modelopt/torch/quantization/plugins/mcore.py exists.

Proposed change:

-with import_plugin("mcore"):
+with import_plugin("megatron.core"):
     from .mcore import *

Also consider adding - :meth:`mcore <modelopt.torch.quantization.plugins.mcore>` to the plugin list in the module docstring for discoverability.

To verify:

  • Check that plugins/mcore.py exists.
  • Confirm which import string ("megatron.core" vs "mcore") should gate plugin loading in your environments.
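
To make the gating semantics discussed above concrete, here is a stand-in sketch of an import_plugin-style context manager that simply suppresses the ImportError raised inside the guarded block; this is an assumption about the helper's behavior, not modelopt's actual implementation.

import contextlib

@contextlib.contextmanager
def import_plugin_stub(plugin_name: str):
    # Stand-in only: let the guarded import run and silently swallow the ImportError
    # raised when the optional dependency behind `plugin_name` is missing.
    try:
        yield
    except ImportError:
        pass  # the real helper may also log or warn; that detail is assumed here

with import_plugin_stub("megatron.core"):
    import megatron.core  # skipped silently on installs without Megatron-Core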
modelopt/torch/quantization/plugins/megatron.py (2)

533-567: Sharded state dict: consider consistency for amax handling.

You special‑case _amax by inserting directly, and use make_sharded_tensors_for_checkpoint for other quantizer tensors with empty shard axes. If future bmm quantizers introduce sharded tensors (e.g., per‑channel), add shard axis mapping here to avoid shape mismatches after TP changes. No action required now, but keep in mind for NVFP4 evolution.


568-587: Redundant amax key remapping.

amax_key and expected_amax_key are identical: f"{prefix}{quantizer_name}._amax". The rename is a no‑op.

Apply:

-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+        # amax keys already match TensorQuantizer state naming; no remap needed
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 598b9ce and 93bab27.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/__init__.py (1 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • initialize_for_megatron (385-393)
  • get_mcore_gpt_model (133-208)
  • sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
  • QuantModule (37-96)
  • _setup (118-126)
  • _setup (163-169)
  • modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • TensorQuantizer (60-1143)
  • is_enabled (389-391)
  • reset_amax (252-256)
  • forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
  • _setup (55-58)
  • _setup (161-164)
  • _setup (239-244)
  • _setup (349-350)
  • _setup (365-369)
  • _setup (388-390)
  • _setup (427-473)
  • _setup (601-612)
  • forward (71-119)
  • forward (170-174)
  • forward (255-256)
  • forward (337-338)
  • forward (352-361)
  • forward (371-383)
  • forward (393-423)
  • forward (475-480)
  • forward (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
  • modelopt_post_restore (117-174)
  • _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)

235-237: Note acknowledged — separating KV-cache configs is reasonable.


521-531: LGTM — targeted KV‑cache test entrypoint.


540-553: LGTM — sharded state dict test for KV‑cache quantization.

modelopt/torch/quantization/plugins/megatron.py (3)

25-25: Confirm TEDotProductAttention import path across supported Megatron‑Core/TE versions.

from megatron.core.extensions.transformer_engine import TEDotProductAttention may vary by version. If older MC/TE versions are supported, consider guarding this import (try/except) and registering conditionally.


467-476: LGTM — consistent quantizer setup for Q/K/V bmm.


520-532: LGTM — quantize Q/K/V post‑RoPE before delegating.

Comment on lines 477 to 528
    def _calibrate_quantizers(self):
        """Calibrate quantizers with minimal dummy tensors."""
        # Get device from parent module parameters
        device = next(self.parameters()).device if self.parameters() else torch.device("cuda")

        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
        batch_size = 1
        seq_len = 1

        # Get dimensions from config
        num_heads = self.config.num_attention_heads
        head_dim = (
            self.config.kv_channels
            if hasattr(self.config, "kv_channels")
            else self.config.hidden_size // num_heads
        )

        # Determine tensor format (default to sbhd if not specified)
        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
        qkv_format = "bshd" if apply_rope_fusion else "sbhd"

        if qkv_format == "sbhd":
            dummy_tensor = torch.randn(
                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
            )
        else:
            dummy_tensor = torch.randn(
                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
            )

        # Calibrate each quantizer
        quantizers = [
            ("q_bmm_quantizer", self.q_bmm_quantizer),
            ("k_bmm_quantizer", self.k_bmm_quantizer),
            ("v_bmm_quantizer", self.v_bmm_quantizer),
        ]

        for _, quantizer in quantizers:
            if quantizer is not None and quantizer.is_enabled:
                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                    quantizer.reset_amax()
                    max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

    def forward(self, query, key, value, *args, **kwargs):

⚠️ Potential issue | 🔴 Critical

Critical: device detection and is_enabled checks in calibration.

  • next(self.parameters()).device if self.parameters() is unsafe: self.parameters() returns a generator (always truthy) and will raise StopIteration if the module has no parameters.
  • quantizer.is_enabled must be called as a method.

Apply:

-        device = next(self.parameters()).device if self.parameters() else torch.device("cuda")
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
+        if param is not None:
+            device = param.device
+        elif buf is not None:
+            device = buf.device
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@
-        for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled:
+        for _, quantizer in quantizers:
+            if quantizer is not None and quantizer.is_enabled():
                 if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                     quantizer.reset_amax()
                     max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def _calibrate_quantizers(self):
        """Calibrate quantizers with minimal dummy tensors."""
        # Get device from parent module parameters
        param = next(self.parameters(recurse=False), None)
        buf = next(self.buffers(recurse=False), None)
        if param is not None:
            device = param.device
        elif buf is not None:
            device = buf.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
        batch_size = 1
        seq_len = 1
        # Get dimensions from config
        num_heads = self.config.num_attention_heads
        head_dim = (
            self.config.kv_channels
            if hasattr(self.config, "kv_channels")
            else self.config.hidden_size // num_heads
        )
        # Determine tensor format (default to sbhd if not specified)
        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
        if qkv_format == "sbhd":
            dummy_tensor = torch.randn(
                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
            )
        else:
            dummy_tensor = torch.randn(
                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
            )
        # Calibrate each quantizer
        quantizers = [
            ("q_bmm_quantizer", self.q_bmm_quantizer),
            ("k_bmm_quantizer", self.k_bmm_quantizer),
            ("v_bmm_quantizer", self.v_bmm_quantizer),
        ]
        for _, quantizer in quantizers:
            if quantizer is not None and quantizer.is_enabled():
                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                    quantizer.reset_amax()
                    max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

    def forward(self, query, key, value, *args, **kwargs):
        ...

Comment on lines +374 to +418
def _test_kv_cache_quant_helper(config, rank, size):
    """Helper function for testing KV cache quantization with TEDotProductAttention."""
    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
    # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
    model = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=1,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=32,
        transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
    ).cuda()

    # Create dummy input for calibration
    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Test KV cache quantization with the given config
    quantized_model = mtq.quantize(model, config, forward_fn)

    # Find TEDotProductAttention modules and verify they have KV cache quantizers
    te_attention_found = False
    for name, module in quantized_model.named_modules():
        # Check if this is a quantized TEDotProductAttention
        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
            te_attention_found = True
            # Verify all expected quantizers exist
            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"

    # Quick smoke test that forward still works
    output = forward_fn(quantized_model)
    assert output is not None, "Forward pass failed"


⚠️ Potential issue | 🟠 Major

Fix assertions: call TensorQuantizer.is_enabled as a method.

is_enabled is a method, not a property. As written, the assertions will always pass because a bound method is truthy.

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    for name, module in quantized_model.named_modules():
        # Check if this is a quantized TEDotProductAttention
        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
            te_attention_found = True
            # Verify all expected quantizers exist
            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 374 to
418, the assertions check TensorQuantizer.is_enabled as an attribute
(module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled) but
is_enabled is a method; change those assertions to call the method
(module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled()) so
they evaluate correctly.
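
A tiny standalone illustration (not project code) of why the uncalled assertions can never fail: a bound method object is always truthy, so only calling is_enabled() exercises the check.

class Quantizer:
    def __init__(self, enabled: bool):
        self._enabled = enabled

    def is_enabled(self) -> bool:
        return self._enabled


q = Quantizer(enabled=False)
assert q.is_enabled          # passes regardless: the bound method itself is truthy
try:
    assert q.is_enabled()    # raises AssertionError: the call returns False
except AssertionError:
    print("calling is_enabled() correctly flags the disabled quantizer")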

Comment on lines +420 to +513
def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
    """Helper for testing KV cache quantization with sharded state dict save/load."""
    # Disable output_layer quantization (same as other sharded state dict tests)
    config["quant_cfg"]["*output_layer*"] = {"enable": False}

    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
    model_ref = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,  # At least 2 layers to test multiple attention modules
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
    ).cuda()

    model_test = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",
    ).cuda()

    prompt_tokens = torch.randint(
        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
    ).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Quantize the reference model
    model_ref = mtq.quantize(model_ref, config, forward_fn)

    # CRITICAL: model_test must also be quantized with the same config
    # Otherwise it won't have the KV cache quantizer keys when loading state dict
    model_test = mtq.quantize(model_test, config, forward_fn)

    # Verify KV cache quantizers were created
    kv_quantizers_found = False
    for name, module in model_ref.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            kv_quantizers_found = True
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

    # Test sharded state dict save/load
    sharded_state_dict_test_helper(
        tmp_path,
        model_ref,
        model_test,
        forward_fn,
        meta_device=False,
        version=None,
    )

    # Verify KV cache quantizers are restored correctly in model_test
    for (name_ref, module_ref), (name_test, module_test) in zip(
        model_ref.named_modules(), model_test.named_modules()
    ):
        if hasattr(module_ref, "k_bmm_quantizer"):
            assert hasattr(module_test, "k_bmm_quantizer"), (
                f"K quantizer missing after restore in {name_test}"
            )
            assert hasattr(module_test, "v_bmm_quantizer"), (
                f"V quantizer missing after restore in {name_test}"
            )

            # Check that quantizer states match
            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
                    f"K quantizer _amax missing in {name_test}"
                )
                if module_ref.k_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                    ), f"K quantizer _amax mismatch in {name_test}"

            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
                    f"V quantizer _amax missing in {name_test}"
                )
                if module_ref.v_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                    ), f"V quantizer _amax mismatch in {name_test}"



⚠️ Potential issue | 🟠 Major

Fix is_enabled checks in sharded state dict helper.

Same issue: call is_enabled().

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    for name, module in model_ref.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            kv_quantizers_found = True
            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 420-513,
the assertions that verify KV cache quantizers use the attribute access
.is_enabled instead of calling the method; update those assertions to call
module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled() so
the checks invoke the method rather than reference a non-callable attribute.

coderabbitai bot added a commit that referenced this pull request Sep 25, 2025
Docstrings generation was requested by @kaix-nv.

* #375 (comment)

The following files were modified:

* `modelopt/torch/quantization/plugins/megatron.py`
* `tests/gpu/torch/quantization/plugins/test_megatron.py`

coderabbitai bot commented Sep 25, 2025

Note

Generated docstrings for this pull request at #376


coderabbitai bot commented Sep 25, 2025

Note

Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it.


Generating unit tests... This may take up to 20 minutes.


coderabbitai bot commented Sep 25, 2025

UTG Post-Process Complete

No new issues were detected in the generated code and all check runs have completed. The unit test generation process has completed successfully.


coderabbitai bot commented Sep 25, 2025

Creating a PR to put the unit tests in...

The changes have been created in this pull request: View PR


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 93bab27 and 5fc56fe.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/__init__.py (1 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/__init__.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.

Applied to files:

  • tests/gpu/torch/quantization/plugins/test_megatron.py
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • initialize_for_megatron (385-393)
  • get_mcore_gpt_model (133-208)
  • sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
  • QuantModule (37-96)
  • _setup (118-126)
  • _setup (163-169)
  • modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • TensorQuantizer (60-1143)
  • is_enabled (389-391)
  • reset_amax (252-256)
  • forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
  • _setup (55-58)
  • _setup (161-164)
  • _setup (239-244)
  • _setup (349-350)
  • _setup (365-369)
  • _setup (388-390)
  • _setup (427-473)
  • _setup (601-612)
  • forward (71-119)
  • forward (170-174)
  • forward (255-256)
  • forward (337-338)
  • forward (352-361)
  • forward (371-383)
  • forward (393-423)
  • forward (475-480)
  • forward (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
  • modelopt_post_restore (117-174)
  • _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

409-411: Fix assertions: call TensorQuantizer.is_enabled()

Line 410 still treats is_enabled as an attribute, so the assertion always passes because a bound method is truthy. Please invoke it.

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"

467-468: Fix assertions: call TensorQuantizer.is_enabled()

Line 467 repeats the same issue—is_enabled must be invoked. Without the call these checks never fail even when the quantizers are disabled.

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
modelopt/torch/quantization/plugins/megatron.py (1)

483-488: Calibrator device/dtype inference must handle param/buffer-less modules

In _calibrate_quantizers (Lines 486-488), param = next(iter(self.parameters()), None) falls back to torch.device("cuda") and torch.float16 whenever the module has no parameters. Many wrapper modules expose only buffers (or can be instantiated on CPU/meta first), so this default hits real code paths and will crash on hosts without CUDA. Please probe both parameters and buffers, and only prefer CUDA if it is actually available.

-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
+        if param is not None:
+            device, dtype = param.device, param.dtype
+        elif buf is not None:
+            device, dtype = buf.device, buf.dtype
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            dtype = torch.float16 if device.type == "cuda" else torch.float32

This matches the safe pattern already used elsewhere in the plugin and avoids hard-crashing on CPU/meta setups.

import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention

⚠️ Potential issue | 🟠 Major

Guard TEDotProductAttention import when TE is unavailable

Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.

-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None

And guard the registration/class definition behind if TEDotProductAttention is not None: so the module remains usable without TE.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from megatron.core.extensions.transformer_engine import TEDotProductAttention
try:
    from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:
    TEDotProductAttention = None
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around line 25, the
unconditional import of TEDotProductAttention will raise ModuleNotFoundError
when Transformer Engine (TE) is not installed; wrap the import in a try/except
that sets TEDotProductAttention = None on ImportError, then only define and
register the _QuantTEDotProductAttention class (and any related registration
calls) inside an if TEDotProductAttention is not None: block so the module
remains importable and Megatron quantization works without TE.
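
A runnable sketch of the guarded-definition pattern the prompt describes, using a toy registry so it works with or without Transformer Engine installed; the real QuantModuleRegistry API is not reproduced here.

registry = {}

try:
    from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:
    TEDotProductAttention = None

if TEDotProductAttention is not None:

    class _QuantTEDotProductAttentionSketch(TEDotProductAttention):
        """Defined and registered only when Transformer Engine is importable."""

    registry["TEDotProductAttention"] = _QuantTEDotProductAttentionSketch

print("registered:", list(registry))  # stays empty on installs without TE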

Comment on lines +423 to +424
config["quant_cfg"]["*output_layer*"] = {"enable": False}


⚠️ Potential issue | 🟠 Major

Do not mutate shared config constants in-place

Line 423 mutates the shared config object (mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG) by updating quant_cfg. Because those constants are reused across parametrized tests (and potentially by library consumers), this causes cross-test contamination and unexpected behavior later in the suite. Make a copy before modifying.

-def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    config = copy.deepcopy(config)
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
-    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 423-424,
the test mutates the shared config constant by assigning
config["quant_cfg"]["*output_layer*"] = {"enable": False}; instead, avoid
in-place mutation by making a copy (use copy.deepcopy(config) or dict deepcopy)
into a local variable and mutate that copy (or create a shallow copy of
quant_cfg and update it) so the original mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG
remain unchanged; replace subsequent uses in the test with the copied config.
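
A standalone illustration (not the test itself) of the contamination the prompt warns about and of the copy-first fix; SHARED_CFG stands in for mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG.

import copy

SHARED_CFG = {"quant_cfg": {"*output_layer*": {"enable": True}}}  # stand-in for the shared constant

def helper_mutating(config):
    config["quant_cfg"]["*output_layer*"] = {"enable": False}   # mutates the shared object

def helper_copying(config):
    config = copy.deepcopy(config)                              # private copy per test
    config["quant_cfg"]["*output_layer*"] = {"enable": False}
    return config

helper_copying(SHARED_CFG)
assert SHARED_CFG["quant_cfg"]["*output_layer*"]["enable"]      # constant untouched
helper_mutating(SHARED_CFG)
assert not SHARED_CFG["quant_cfg"]["*output_layer*"]["enable"]  # change leaks into later tests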

@kaix-nv kaix-nv requested a review from jingyu-ml September 25, 2025 22:30
Contributor

@jingyu-ml jingyu-ml left a comment


LGTM
