 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
 
-from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
+from torchao.dtypes.floatx.float8_layout import preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -236,12 +237,8 @@ def test_serialization(self, mode: str):
             new_layer = getattr(new_model, layer_name)
 
             # Compare weights
-            if mode == "weight-only":
-                original_weight = original_layer.weight.tensor_impl.float8_data.to(
-                    torch.float32
-                )
-                new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32)
-            else:
+            if mode == "static":
+                # note: we haven't migrated static quant to the new API
                 original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
                     torch.float32
                 )
@@ -250,6 +247,9 @@ def test_serialization(self, mode: str):
                         torch.float32
                     )
                 )
+            else:
+                original_weight = original_layer.weight.float8_data.to(torch.float32)
+                new_weight = new_layer.weight.float8_data.to(torch.float32)
 
             assert torch.allclose(original_weight, new_weight), (
                 f"Weights do not match for {layer_name}"
@@ -324,19 +324,15 @@ def test_mm_float8dq_per_row(
         quant_weight = test_linear.weight
 
-        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
-        weight_impl = quant_weight.original_weight_tensor.tensor_impl
-
-        self.assertTrue(hasattr(weight_impl, "float8_data"))
-        self.assertTrue(hasattr(weight_impl, "scale"))
-        self.assertFalse(weight_impl.transposed)
+        self.assertTrue(hasattr(quant_weight, "float8_data"))
+        self.assertTrue(hasattr(quant_weight, "scale"))
 
         # Verify scale shape for row-wise quantization
         expected_scale_shape = (out_features, 1)
-        actual_scale_shape = weight_impl.scale.shape
+        actual_scale_shape = quant_weight.scale.shape
         self.assertEqual(actual_scale_shape, expected_scale_shape)
 
-        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+        self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))
 
         input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
@@ -357,7 +353,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 
         device = "cuda"
@@ -387,7 +383,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_dequantize_affine_float8_scale_broadcasting(self):
+    def test__dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
@@ -431,24 +427,24 @@ def test_float8_tensor_slicing_basic(self, granularity):
             model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         )
 
-        weight_impl = model.weight.original_weight_tensor.tensor_impl
+        weight = model.weight
 
         # Test dimension 0 slicing (rows)
-        sliced_0 = weight_impl[10:20]
+        sliced_0 = weight[10:20]
         self.assertEqual(sliced_0.shape, (10, 64))
 
         # Test dimension 1 slicing (columns)
-        sliced_1 = weight_impl[:, 20:40]
+        sliced_1 = weight[:, 20:40]
         self.assertEqual(sliced_1.shape, (32, 20))
 
         # Test combined slicing
-        sliced_both = weight_impl[5:15, 10:30]
+        sliced_both = weight[5:15, 10:30]
         self.assertEqual(sliced_both.shape, (10, 20))
 
         # Verify the sliced tensors are still Float8 tensors
-        self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))
+        self.assertTrue(isinstance(sliced_0, Float8Tensor))
+        self.assertTrue(isinstance(sliced_1, Float8Tensor))
+        self.assertTrue(isinstance(sliced_both, Float8Tensor))
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -466,16 +462,15 @@ def test_float8_tensor_slicing_per_tensor(self):
         )
 
         original_weight = model.weight
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale
+        original_scale = original_weight.scale
 
         # Test slicing
         sliced_weight = original_weight[10:20, 20:40]
-        sliced_impl = sliced_weight.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_weight.scale
 
         # For per-tensor quantization, scale should be identical
-        self.assertTrue(torch.equal(original_scale, sliced_impl.scale))
-        self.assertEqual(sliced_impl.scale.numel(), 1)
+        self.assertTrue(torch.equal(original_scale, sliced_scale))
+        self.assertEqual(sliced_scale.numel(), 1)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -497,27 +492,26 @@ def test_float8_tensor_slicing_per_row(self):
         )
 
         original_weight = model.weight  # Shape: (32, 64)
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale  # Shape: (32, 1)
+        original_scale = model.weight.scale  # Shape: (32, 1)
 
         # Test row slicing (dimension 0)
         sliced_rows = original_weight[10:20]  # Shape: (10, 64)
-        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_rows.scale
 
         # Scale should be sliced to match the rows
         expected_scale_shape = (10, 1)
-        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+        self.assertEqual(sliced_scale.shape, expected_scale_shape)
 
         # Verify the scale values are correct (should be subset of original)
-        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+        self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))
 
         # Test column slicing (dimension 1) - scale should not change for per-row
         sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
-        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+        sliced_cols_scale = sliced_cols.scale
 
         # Scale shape should remain the same since we're not changing rows
-        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
-        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+        self.assertEqual(sliced_cols_scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_scale, original_scale))
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -552,11 +546,11 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
     )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_functional_correctness(self, granularity):
         """Test that sliced tensors produce correct results in computations"""
         device = "cuda"
@@ -579,15 +573,16 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_weight_slice = quant_model.weight[0:16, 0:32]
 
         # Verify that the sliced weights maintain Float8 properties
-        self.assertTrue(hasattr(quant_weight_slice, "original_weight_tensor"))
-        sliced_impl = quant_weight_slice.original_weight_tensor.tensor_impl
-        self.assertTrue(isinstance(sliced_impl, Float8AQTTensorImpl))
+        self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
+        self.assertTrue(hasattr(quant_weight_slice, "scale"))
+        sliced_impl = quant_weight_slice
+        self.assertTrue(isinstance(sliced_impl, Float8Tensor))
 
         # Verify sliced weight shapes
         self.assertEqual(sliced_impl.float8_data.shape, (16, 32))
 
         # Get original quantized weight implementation for scale comparison
-        original_quant_impl = quant_model.weight.original_weight_tensor.tensor_impl
+        original_quant_impl = quant_model.weight
 
         # Verify scale properties based on granularity
         if isinstance(granularity, PerTensor):
@@ -604,7 +599,7 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         )
 
         # Verify that sliced quantized data matches the correct slice from original
-        original_float8_data_slice = original_quant_impl.float8_data[0:16, 0:32]
+        original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]
         self.assertTrue(
             torch.equal(sliced_impl.float8_data, original_float8_data_slice)
         )
@@ -675,46 +670,6 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
-
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)