 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -374,8 +373,10 @@ def test_fp8_real_quantize():
 
 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -386,43 +387,45 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -432,7 +435,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -441,29 +444,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -473,32 +478,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+
 
 @pytest.mark.parametrize(
     "config",
@@ -509,16 +520,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the 
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
@@ -530,7 +539,7 @@ def test_kv_cache_quant(config):
 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
     """Test KV cache quantization with sharded state dict save/load.
-
+
     This test verifies the complete workflow of saving and loading KV cache quantized
     models with distributed checkpointing, ensuring quantizer states are properly
     preserved across the save/load cycle.
@@ -539,5 +548,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
     )
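
The hasattr/is_enabled/_amax checks above are repeated in both helpers; the sketch below shows how they could be factored into one small utility. This is illustrative only, not part of the diff: it assumes nothing beyond the attribute names used in the tests (k_bmm_quantizer, v_bmm_quantizer, is_enabled, _amax), and collect_kv_quantizer_amax is a hypothetical name rather than an existing modelopt API.

import torch


def collect_kv_quantizer_amax(model: torch.nn.Module):
    """Map module name -> (k_amax, v_amax) for every module carrying KV cache quantizers."""
    amax_by_module = {}
    for name, module in model.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            # Mirrors the enablement checks performed in the tests above.
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
            amax_by_module[name] = (
                getattr(module.k_bmm_quantizer, "_amax", None),
                getattr(module.v_bmm_quantizer, "_amax", None),
            )
    return amax_by_module

With such a helper, the amax comparison in _test_kv_cache_sharded_state_dict_helper reduces to collecting the two dictionaries from model_ref and model_test and calling torch.allclose on matching non-None entries.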