|
26 | 26 | from megatron.core.transformer import MegatronModule
|
27 | 27 | from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
|
28 | 28 | from megatron.core.utils import get_tensor_model_parallel_group_if_none
|
| 29 | +from megatron.core.extensions.transformer_engine import TEDotProductAttention |
29 | 30 |
|
30 | 31 | from modelopt.torch.opt.plugins.megatron import (
|
31 | 32 | _MegatronMLP,
|
32 | 33 | register_modelopt_extra_state_callbacks,
|
33 | 34 | )
|
34 | 35 | from modelopt.torch.utils.distributed import ParallelState
|
35 | 36 |
|
36 |
| -from ..nn import QuantModuleRegistry, TensorQuantizer |
| 37 | +from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer |
37 | 38 | from ..nn.modules.quant_linear import RealQuantLinear
|
38 | 39 | from ..qtensor import QTensorWrapper
|
39 | 40 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
|
| 41 | +from ..model_calib import max_calibrate |
40 | 42 |
|
41 | 43 | __all__ = []
|
42 | 44 |
|
@@ -460,3 +462,149 @@ class _RealQuantMegatronRowParallelLinear(
|
460 | 462 |
|
461 | 463 | def forward(self, input, *args, **kwargs):
|
462 | 464 | return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
|
| 465 | +
| 466 | +
| 467 | +@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
| 468 | +class _QuantTEDotProductAttention(QuantModule):
| 469 | +    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
| 470 | +
| 471 | +    def _setup(self):
| 472 | +        """Initialize quantizers for Q, K, V tensors."""
| 473 | +        self.q_bmm_quantizer = TensorQuantizer()
| 474 | +        self.k_bmm_quantizer = TensorQuantizer()
| 475 | +        self.v_bmm_quantizer = TensorQuantizer()
| 476 | +
| 477 | +    def _calibrate_quantizers(self):
| 478 | +        """Calibrate quantizers with minimal dummy tensors."""
| 479 | +        # Get device from this module's parameters; fall back to CUDA if it has none
| 480 | +        device = next((p.device for p in self.parameters()), torch.device("cuda"))
| 481 | +
| 482 | +        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
| 483 | +        batch_size = 1
| 484 | +        seq_len = 1
| 485 | +
| 486 | +        # Get dimensions from config
| 487 | +        num_heads = self.config.num_attention_heads
| 488 | +        head_dim = getattr(self.config, "kv_channels", None) or self.config.hidden_size // num_heads
| 489 | +
| 490 | +        # Determine tensor format (default to sbhd if not specified)
| 491 | +        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
| 492 | +        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
| 493 | +
| 494 | +        if qkv_format == "sbhd":
| 495 | +            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
| 496 | +        else:
| 497 | +            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
| 498 | +
| 499 | +        # Calibrate each quantizer
| 500 | +        quantizers = [
| 501 | +            ("q_bmm_quantizer", self.q_bmm_quantizer),
| 502 | +            ("k_bmm_quantizer", self.k_bmm_quantizer),
| 503 | +            ("v_bmm_quantizer", self.v_bmm_quantizer),
| 504 | +        ]
| 505 | +
| 506 | +        for _, quantizer in quantizers:
| 507 | +            if quantizer is not None and quantizer.is_enabled:
| 508 | +                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
| 509 | +                    quantizer.reset_amax()
| 510 | +                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
| 511 | +
| 512 | +    def forward(self, query, key, value, *args, **kwargs):
| 513 | +        """Quantize Q, K, V after RoPE for KV cache quantization.
| 514 | +
| 515 | +        TEDotProductAttention receives Q, K and V after RoPE has already been
| 516 | +        applied, so they can be quantized directly here.
| 517 | +        """
| 518 | +        # Quantize Q, K, V
| 519 | +        query = self.q_bmm_quantizer(query)
| 520 | +        key = self.k_bmm_quantizer(key)
| 521 | +        value = self.v_bmm_quantizer(value)
| 522 | +
| 523 | +        return super().forward(query, key, value, *args, **kwargs)
| 524 | +
| 525 | +    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
| 526 | +        """Create a sharded state dictionary for distributed checkpointing."""
| 527 | +        sharded_state_dict = {}
| 528 | +
| 529 | +        # First add non-quantizer parameters
| 530 | +        for k, v in self.state_dict(prefix="", keep_vars=True).items():
| 531 | +            if isinstance(v, torch.Tensor) and "_quantizer" not in k:
| 532 | +                sharded_state_dict[prefix + k] = v
| 533 | +
| 534 | +        # Process _amax in bmm_quantizers
| 535 | +        for name, quantizer in [
| 536 | +            ("q_bmm_quantizer", self.q_bmm_quantizer),
| 537 | +            ("k_bmm_quantizer", self.k_bmm_quantizer),
| 538 | +            ("v_bmm_quantizer", self.v_bmm_quantizer),
| 539 | +        ]:
| 540 | +            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
| 541 | +                amax_key = f"{prefix}{name}._amax"
| 542 | +                sharded_state_dict[amax_key] = quantizer._amax
| 543 | +
| 544 | +        # Process other quantizer parameters in bmm_quantizers
| 545 | +        quantizer_state_dict = {}
| 546 | +        for k, v in self.state_dict(prefix="", keep_vars=True).items():
| 547 | +            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
| 548 | +                quantizer_state_dict[k] = v
| 549 | +
| 550 | +        if quantizer_state_dict:
| 551 | +            sharded_state_dict.update(
| 552 | +                **make_sharded_tensors_for_checkpoint(
| 553 | +                    quantizer_state_dict, prefix, {}, sharded_offsets
| 554 | +                )
| 555 | +            )
| 556 | +
| 557 | +        return sharded_state_dict
| 558 | +
| 559 | +    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
| 560 | +        """Handle loading state dict for quantizers."""
| 561 | +        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
| 562 | +            full_prefix = f"{prefix}{quantizer_name}."
| 563 | +            amax_key = f"{prefix}{quantizer_name}._amax"
| 564 | +
| 565 | +            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
| 566 | +            if amax_key in state_dict:
| 567 | +                expected_amax_key = f"{full_prefix}_amax"
| 568 | +                state_dict[expected_amax_key] = state_dict.pop(amax_key)
| 569 | +
| 570 | +        # Handle other quantizer states
| 571 | +        for k in list(state_dict.keys()):
| 572 | +            if "_quantizer" in k and "_amax" not in k:
| 573 | +                name = k.split(prefix)[-1] if prefix else k
| 574 | +                if name in self.state_dict():
| 575 | +                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
| 576 | +
| 577 | +        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
| 578 | +
| 579 | +    def modelopt_post_restore(self, name=""):
| 580 | +        """Restore quantizer states after model loading."""
| 581 | +        super().modelopt_post_restore(name)
| 582 | +
| 583 | +        def _check_unsupported_states(quantizer):
| 584 | +            if not hasattr(quantizer, "state_dict"):
| 585 | +                return
| 586 | +
| 587 | +            for k in quantizer.state_dict().keys():
| 588 | +                if k not in ["_amax", "_pre_quant_scale"]:
| 589 | +                    warnings.warn(
| 590 | +                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
| 591 | +                        f"incorrect. Please implement a custom restore for {k}."
| 592 | +                    )
| 593 | +
| 594 | +        calibration_needed = False
| 595 | +
| 596 | +        for quantizer_name, quantizer in [
| 597 | +            ("q_bmm_quantizer", self.q_bmm_quantizer),
| 598 | +            ("k_bmm_quantizer", self.k_bmm_quantizer),
| 599 | +            ("v_bmm_quantizer", self.v_bmm_quantizer),
| 600 | +        ]:
| 601 | +            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
| 602 | +                continue
| 603 | +
| 604 | +            _check_unsupported_states(quantizer)
| 605 | +
| 606 | +            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
| 607 | +                calibration_needed = True
| 608 | +
| 609 | +        if calibration_needed:
| 610 | +            self._calibrate_quantizers()
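
For context, here is a minimal sketch of how the new `_QuantTEDotProductAttention` wrapper would typically be exercised through the standard `modelopt.torch.quantization` entry point. The config name, wildcard pattern keys, and dataloader below are illustrative assumptions for the sketch, not something shipped in this change:

```python
import modelopt.torch.quantization as mtq

# Illustrative FP8 (E4M3) per-tensor config that enables only the Q/K/V BMM
# quantizers introduced by this change; the pattern keys are assumptions.
KV_CACHE_FP8_CFG = {
    "quant_cfg": {
        "*[qkv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        "default": {"enable": False},
    },
    "algorithm": "max",
}


def quantize_kv_cache(model, calib_dataloader):
    """Wrap TEDotProductAttention layers and calibrate their Q/K/V quantizers (sketch)."""

    def forward_loop(m):
        # Run a few calibration batches so the max calibrator records amax values
        # for the q/k/v_bmm_quantizers (hypothetical dataloader and batch format).
        for batch in calib_dataloader:
            m(**batch)

    # mtq.quantize replaces registered modules (including TEDotProductAttention,
    # per the registration above) with their quantized counterparts and runs the
    # calibration loop.
    return mtq.quantize(model, KV_CACHE_FP8_CFG, forward_loop)
```

After calibration, the per-layer amax values appear in `sharded_state_dict()` under keys such as `<prefix>.k_bmm_quantizer._amax`, which is what the checkpointing hooks above serialize and restore.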