Support kv cache quantization for mcore using bmm_quantizers #375
base: main
Conversation
Signed-off-by: Kai Xu <[email protected]>
Walkthrough
Adds the mcore plugin import; introduces a QuantModule-based _QuantTEDotProductAttention that wraps TEDotProductAttention with q/k/v bmm quantizers.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
actor Trainer
participant GPT as Megatron GPT
participant QMod as _QuantTEDotProductAttention
participant QZ as TensorQuantizers
participant TE as TEDotProductAttention (base)
Trainer->>GPT: build model (transformer_impl="modelopt")
GPT->>QMod: instantiate -> _setup()
QMod->>QZ: create q_bmm / k_bmm / v_bmm
Trainer->>QMod: optional calibration
alt calibration enabled
QMod->>QZ: _calibrate_quantizers() / max_calibrate
end
Trainer->>QMod: forward(query,key,value,...)
QMod->>QZ: quantize Q / K / V (if enabled)
QMod->>TE: call parent forward with (quantized) Q/K/V
TE-->>Trainer: attention output
sequenceDiagram
autonumber
participant Saver as Checkpoint Save
participant QMod as _QuantTEDotProductAttention
participant SD as sharded_state_dict
participant Loader as Checkpoint Load
Saver->>QMod: sharded_state_dict(prefix, offsets, metadata)
QMod->>SD: include non-quant params, quantizer states (+ _amax)
SD-->>Saver: checkpoint
Loader->>QMod: _load_from_state_dict(state_dict, prefix, ...)
QMod->>QMod: remap amax keys, reshape quantizer tensors
Loader->>QMod: modelopt_post_restore()
QMod->>QMod: validate/prepare quantizers, trigger calibration if needed
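The flow above can be summarized in a minimal sketch (illustrative only; apart from TEDotProductAttention, TensorQuantizer, and the q/k/v bmm quantizer names taken from this PR, the exact structure is an assumption, not the shipped implementation):

from megatron.core.extensions.transformer_engine import TEDotProductAttention
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

class _QuantTEDotProductAttentionSketch(TEDotProductAttention):
    def _setup(self):
        # One quantizer per bmm input: query, key, value
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()

    def forward(self, query, key, value, *args, **kwargs):
        # Quantize post-RoPE Q/K/V, then delegate to the TE attention base class
        query = self.q_bmm_quantizer(query)
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)
        return super().forward(query, key, value, *args, **kwargs)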
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)
588-620: Fix is_enabled checks in post-restore to avoid unconditional calibration. Call the method; otherwise the check always evaluates truthy and may trigger unnecessary calibration.
Apply:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
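For context, the reason the original check always passes: a bound method object is truthy, so `not quantizer.is_enabled` is always False. A standalone illustration (not project code):

class _Demo:
    def is_enabled(self):
        return False

q = _Demo()
print(bool(q.is_enabled))    # True  -- the bound method object itself is truthy
print(bool(q.is_enabled()))  # False -- calling it returns the actual flag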
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/__init__.py (1)
75-76: Gate the new mcore plugin on an importable package (likely megatron.core) and verify the plugin file exists.
import_plugin("mcore") will only import .mcore if a top-level module named mcore exists. That's probably not what you want. If the plugin depends on Megatron-Core, gate on "megatron.core" (or whichever module actually guarantees availability) and ensure modelopt/torch/quantization/plugins/mcore.py exists.
Proposed change:
-with import_plugin("mcore"):
+with import_plugin("megatron.core"):
     from .mcore import *
Also consider adding :meth:`mcore <modelopt.torch.quantization.plugins.mcore>` to the plugin list in the module docstring for discoverability.
To verify:
- Check that plugins/mcore.py exists.
- Confirm which import string ("megatron.core" vs "mcore") should gate plugin loading in your environments.
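As background, optional plugin imports of this kind are usually gated with a suppress-ImportError pattern; a minimal standalone sketch of that pattern (an assumption about the intent of import_plugin, not modelopt's actual implementation):

import contextlib

# Gate the optional plugin on the package that actually guarantees availability.
# If Megatron-Core is not installed, the guarded import block is skipped silently.
with contextlib.suppress(ImportError):
    import megatron.core  # noqa: F401
    # from .mcore import *  # the plugin import would follow the gate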
modelopt/torch/quantization/plugins/megatron.py (2)
533-567: Sharded state dict: consider consistency for amax handling. You special-case _amax by inserting directly, and use make_sharded_tensors_for_checkpoint for other quantizer tensors with empty shard axes. If future bmm quantizers introduce sharded tensors (e.g., per-channel), add shard axis mapping here to avoid shape mismatches after TP changes. No action required now, but keep in mind for NVFP4 evolution.
568-587: Redundant amax key remapping. amax_key and expected_amax_key are identical: f"{prefix}{quantizer_name}._amax". The rename is a no-op.
Apply:
-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+        # amax keys already match TensorQuantizer state naming; no remap needed
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/__init__.py (1 hunks)
- modelopt/torch/quantization/plugins/megatron.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
- initialize_for_megatron (385-393)
- get_mcore_gpt_model (133-208)
- sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
- megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
- is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
- max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
- QuantModule (37-96)
- _setup (118-126)
- _setup (163-169)
- modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
- TensorQuantizer (60-1143)
- is_enabled (389-391)
- reset_amax (252-256)
- forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
- _setup (55-58), (161-164), (239-244), (349-350), (365-369), (388-390), (427-473), (601-612)
- forward (71-119), (170-174), (255-256), (337-338), (352-361), (371-383), (393-423), (475-480), (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
- modelopt_post_restore (117-174)
- _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
235-237: Note acknowledged; separating KV-cache configs is reasonable.
521-531: LGTM; targeted KV-cache test entrypoint.
540-553: LGTM; sharded state dict test for KV-cache quantization.
modelopt/torch/quantization/plugins/megatron.py (3)
25-25: Confirm the TEDotProductAttention import path across supported Megatron-Core/TE versions. from megatron.core.extensions.transformer_engine import TEDotProductAttention may vary by version. If older MC/TE versions are supported, consider guarding this import (try/except) and registering conditionally.
467-476: LGTM; consistent quantizer setup for Q/K/V bmm.
520-532: LGTM; quantize Q/K/V post-RoPE before delegating.
    def _calibrate_quantizers(self):
        """Calibrate quantizers with minimal dummy tensors."""
        # Get device from parent module parameters
        device = next(self.parameters()).device if self.parameters() else torch.device("cuda")

        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
        batch_size = 1
        seq_len = 1

        # Get dimensions from config
        num_heads = self.config.num_attention_heads
        head_dim = (
            self.config.kv_channels
            if hasattr(self.config, "kv_channels")
            else self.config.hidden_size // num_heads
        )

        # Determine tensor format (default to sbhd if not specified)
        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
        qkv_format = "bshd" if apply_rope_fusion else "sbhd"

        if qkv_format == "sbhd":
            dummy_tensor = torch.randn(
                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
            )
        else:
            dummy_tensor = torch.randn(
                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
            )

        # Calibrate each quantizer
        quantizers = [
            ("q_bmm_quantizer", self.q_bmm_quantizer),
            ("k_bmm_quantizer", self.k_bmm_quantizer),
            ("v_bmm_quantizer", self.v_bmm_quantizer),
        ]

        for _, quantizer in quantizers:
            if quantizer is not None and quantizer.is_enabled:
                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                    quantizer.reset_amax()
                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

    def forward(self, query, key, value, *args, **kwargs):
Critical: device detection and is_enabled checks in calibration.
The expression next(self.parameters()).device if self.parameters() else torch.device("cuda") is unsafe: self.parameters() returns a generator (always truthy), and next() will raise StopIteration if the module has no parameters. quantizer.is_enabled must be called as a method.
Apply:
- device = next(self.parameters()).device if self.parameters() else torch.device("cuda")
+ param = next(self.parameters(recurse=False), None)
+ buf = next(self.buffers(recurse=False), None)
+ if param is not None:
+ device = param.device
+ elif buf is not None:
+ device = buf.device
+ else:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@
- for _, quantizer in quantizers:
- if quantizer is not None and quantizer.is_enabled:
+ for _, quantizer in quantizers:
+ if quantizer is not None and quantizer.is_enabled():
if not hasattr(quantizer, "_amax") or quantizer._amax is None:
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    def _calibrate_quantizers(self):
        """Calibrate quantizers with minimal dummy tensors."""
        # Get device from parent module parameters or buffers
        param = next(self.parameters(recurse=False), None)
        buf = next(self.buffers(recurse=False), None)
        if param is not None:
            device = param.device
        elif buf is not None:
            device = buf.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
        batch_size = 1
        seq_len = 1

        # Get dimensions from config
        num_heads = self.config.num_attention_heads
        head_dim = (
            self.config.kv_channels
            if hasattr(self.config, "kv_channels")
            else self.config.hidden_size // num_heads
        )

        # Determine tensor format (default to sbhd if not specified)
        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
        qkv_format = "bshd" if apply_rope_fusion else "sbhd"

        if qkv_format == "sbhd":
            dummy_tensor = torch.randn(
                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
            )
        else:
            dummy_tensor = torch.randn(
                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
            )

        # Calibrate each quantizer
        quantizers = [
            ("q_bmm_quantizer", self.q_bmm_quantizer),
            ("k_bmm_quantizer", self.k_bmm_quantizer),
            ("v_bmm_quantizer", self.v_bmm_quantizer),
        ]

        for _, quantizer in quantizers:
            if quantizer is not None and quantizer.is_enabled():
                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                    quantizer.reset_amax()
                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

    def forward(self, query, key, value, *args, **kwargs):
        ...
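A side note on the device-detection issue above: self.parameters() returns a generator object, which is truthy even when empty, so the else branch is unreachable and next() can raise StopIteration. A standalone illustration (not project code):

import torch

m = torch.nn.Identity()      # a module with no parameters
gen = m.parameters()
print(bool(gen))             # True  -- generator objects are always truthy
print(next(gen, None))       # None  -- yet the generator yields nothing
# next(m.parameters()) without a default would raise StopIteration here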
def _test_kv_cache_quant_helper(config, rank, size):
    """Helper function for testing KV cache quantization with TEDotProductAttention."""
    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
    # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
    model = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=1,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=32,
        transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
    ).cuda()

    # Create dummy input for calibration
    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Test KV cache quantization with the given config
    quantized_model = mtq.quantize(model, config, forward_fn)

    # Find TEDotProductAttention modules and verify they have KV cache quantizers
    te_attention_found = False
    for name, module in quantized_model.named_modules():
        # Check if this is a quantized TEDotProductAttention
        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
            te_attention_found = True
            # Verify all expected quantizers exist
            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"

    # Quick smoke test that forward still works
    output = forward_fn(quantized_model)
    assert output is not None, "Forward pass failed"
Fix assertions: call TensorQuantizer.is_enabled as a method.
is_enabled is a method, not a property. As written, the assertions will always pass because a bound method is truthy.
Apply:
- assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
- assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+ assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+ assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    for name, module in quantized_model.named_modules():
        # Check if this is a quantized TEDotProductAttention
        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
            te_attention_found = True
            # Verify all expected quantizers exist
            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 374 to
418, the assertions check TensorQuantizer.is_enabled as an attribute
(module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled) but
is_enabled is a method; change those assertions to call the method
(module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled()) so
they evaluate correctly.
def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
    """Helper for testing KV cache quantization with sharded state dict save/load."""
    # Disable output_layer quantization (same as other sharded state dict tests)
    config["quant_cfg"]["*output_layer*"] = {"enable": False}

    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
    model_ref = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,  # At least 2 layers to test multiple attention modules
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
    ).cuda()

    model_test = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",
    ).cuda()

    prompt_tokens = torch.randint(
        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
    ).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Quantize the reference model
    model_ref = mtq.quantize(model_ref, config, forward_fn)

    # CRITICAL: model_test must also be quantized with the same config
    # Otherwise it won't have the KV cache quantizer keys when loading state dict
    model_test = mtq.quantize(model_test, config, forward_fn)

    # Verify KV cache quantizers were created
    kv_quantizers_found = False
    for name, module in model_ref.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            kv_quantizers_found = True
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

    # Test sharded state dict save/load
    sharded_state_dict_test_helper(
        tmp_path,
        model_ref,
        model_test,
        forward_fn,
        meta_device=False,
        version=None,
    )

    # Verify KV cache quantizers are restored correctly in model_test
    for (name_ref, module_ref), (name_test, module_test) in zip(
        model_ref.named_modules(), model_test.named_modules()
    ):
        if hasattr(module_ref, "k_bmm_quantizer"):
            assert hasattr(module_test, "k_bmm_quantizer"), (
                f"K quantizer missing after restore in {name_test}"
            )
            assert hasattr(module_test, "v_bmm_quantizer"), (
                f"V quantizer missing after restore in {name_test}"
            )

            # Check that quantizer states match
            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
                    f"K quantizer _amax missing in {name_test}"
                )
                if module_ref.k_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                    ), f"K quantizer _amax mismatch in {name_test}"

            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
                    f"V quantizer _amax missing in {name_test}"
                )
                if module_ref.v_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                    ), f"V quantizer _amax mismatch in {name_test}"
Fix is_enabled checks in sharded state dict helper.
Same issue: call is_enabled().
Apply:
- assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
- assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+ assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+ assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    for name, module in model_ref.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            kv_quantizers_found = True
            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 420-513,
the assertions that verify KV cache quantizers use the attribute access
.is_enabled instead of calling the method; update those assertions to call
module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled() so
the checks invoke the method rather than reference a non-callable attribute.
Docstrings generation was requested by @kaix-nv.
* #375 (comment)
The following files were modified:
* `modelopt/torch/quantization/plugins/megatron.py`
* `tests/gpu/torch/quantization/plugins/test_megatron.py`
Note: Generated docstrings for this pull request at #376
Note: Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it. Generating unit tests... This may take up to 20 minutes.
✅ UTG Post-Process Complete: No new issues were detected in the generated code and all check runs have completed. The unit test generation process has completed successfully.
Creating a PR to put the unit tests in... The changes have been created in this pull request: View PR
Signed-off-by: Kai Xu <[email protected]>
93bab27 to 5fc56fe (Compare)
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/__init__.py (1 hunks)
- modelopt/torch/quantization/plugins/megatron.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/plugins/__init__.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.
Applied to files:
tests/gpu/torch/quantization/plugins/test_megatron.py
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
- initialize_for_megatron (385-393)
- get_mcore_gpt_model (133-208)
- sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
- megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
- is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
- max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
- QuantModule (37-96)
- _setup (118-126)
- _setup (163-169)
- modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
- TensorQuantizer (60-1143)
- is_enabled (389-391)
- reset_amax (252-256)
- forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
- _setup (55-58), (161-164), (239-244), (349-350), (365-369), (388-390), (427-473), (601-612)
- forward (71-119), (170-174), (255-256), (337-338), (352-361), (371-383), (393-423), (475-480), (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
- modelopt_post_restore (117-174)
- _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)
409-411: Fix assertions: call TensorQuantizer.is_enabled(). Line 410 still treats is_enabled as an attribute, so the assertion always passes because a bound method is truthy. Please invoke it.
- assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
- assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+ assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+ assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
467-468: Fix assertions: call TensorQuantizer.is_enabled(). Line 467 repeats the same issue; is_enabled must be invoked. Without the call these checks never fail even when the quantizers are disabled.
- assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
- assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+ assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+ assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
modelopt/torch/quantization/plugins/megatron.py (1)
483-488: Calibrator device/dtype inference must handle param/buffer-less modules. In _calibrate_quantizers (Lines 486-488), param = next(iter(self.parameters()), None) falls back to torch.device("cuda") and torch.float16 whenever the module has no parameters. Many wrapper modules expose only buffers (or can be instantiated on CPU/meta first), so this default hits real code paths and will crash on hosts without CUDA. Please probe both parameters and buffers, and only prefer CUDA if it is actually available.
-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
+        if param is not None:
+            device, dtype = param.device, param.dtype
+        elif buf is not None:
+            device, dtype = buf.device, buf.dtype
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            dtype = torch.float16 if device.type == "cuda" else torch.float32
This matches the safe pattern already used elsewhere in the plugin and avoids hard-crashing on CPU/meta setups.
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention
Guard TEDotProductAttention import when TE is unavailable
Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+ from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+ TEDotProductAttention = None
And guard the registration/class definition behind if TEDotProductAttention is not None: so the module remains usable without TE.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
try:
    from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:
    TEDotProductAttention = None
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around line 25, the
unconditional import of TEDotProductAttention will raise ModuleNotFoundError
when Transformer Engine (TE) is not installed; wrap the import in a try/except
that sets TEDotProductAttention = None on ImportError, then only define and
register the _QuantTEDotProductAttention class (and any related registration
calls) inside an if TEDotProductAttention is not None: block so the module
remains importable and Megatron quantization works without TE.
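A minimal sketch of the guarded definition the comment asks for (structure only; the class body and the registration call are elided because they are not part of this diff):

try:
    from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:
    TEDotProductAttention = None

if TEDotProductAttention is not None:

    class _QuantTEDotProductAttention(TEDotProductAttention):
        # quantizer setup and forward as introduced in this PR
        ...

    # register _QuantTEDotProductAttention with the quantization plugin here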
config["quant_cfg"]["*output_layer*"] = {"enable": False} | ||
|
Do not mutate shared config constants in-place
Line 423 mutates the shared config object (mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG) by updating quant_cfg. Because those constants are reused across parametrized tests (and potentially by library consumers), this causes cross-test contamination and unexpected behavior later in the suite. Make a copy before modifying.
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    config = copy.deepcopy(config)
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 423-424,
the test mutates the shared config constant by assigning
config["quant_cfg"]["*output_layer*"] = {"enable": False}; instead, avoid
in-place mutation by making a copy (use copy.deepcopy(config) or dict deepcopy)
into a local variable and mutate that copy (or create a shallow copy of
quant_cfg and update it) so the original mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG
remain unchanged; replace subsequent uses in the test with the copied config.
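To illustrate why the in-place mutation leaks across parametrized tests (standalone example; the FP8_KV_CFG dict below is a stand-in, not the real constant):

import copy

FP8_KV_CFG = {"quant_cfg": {"*": {"enable": True}}}  # stand-in for a shared module-level config

def bad_helper(config):
    config["quant_cfg"]["*output_layer*"] = {"enable": False}  # mutates the shared object

def good_helper(config):
    config = copy.deepcopy(config)  # local copy; the shared constant stays untouched
    config["quant_cfg"]["*output_layer*"] = {"enable": False}

bad_helper(FP8_KV_CFG)
print("*output_layer*" in FP8_KV_CFG["quant_cfg"])  # True  -- contamination persists

FP8_KV_CFG = {"quant_cfg": {"*": {"enable": True}}}  # reset
good_helper(FP8_KV_CFG)
print("*output_layer*" in FP8_KV_CFG["quant_cfg"])  # False -- constant unchanged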
LGTM
What does this PR do?
Type of change: New feature
Overview: Adds KV cache quantization support for Megatron-Core models by wrapping TEDotProductAttention with q/k/v bmm quantizers.
Usage
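The tests in this PR suggest roughly the following usage (illustrative sketch; model construction and the prompt batch are placeholders, not a prescribed API):

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.plugins.megatron_generate import megatron_prefill

model = ...          # an mcore GPT model built with a TEDotProductAttention spec (e.g., transformer_impl="modelopt")
prompt_tokens = ...  # a representative token batch for calibration

def forward_fn(model):
    return megatron_prefill(model, prompt_tokens)

# KV-cache configs (e.g., mtq.FP8_KV_CFG or mtq.NVFP4_KV_CFG) enable the k/v bmm quantizers
# on the attention modules during calibration.
model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)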
Testing
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config1]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config0]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config0] PASSED
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config1] PASSED
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests