
Commit 897ec7e

Add Float8Tensor
Summary:

Added Float8Tensor that works for:

* fbgemm: per-row activation + per-row weight, calling the torch.ops.fbgemm.f8f8bf16_rowwise kernel
* aten: per-row/per-tensor activation + per-row/per-tensor weight, calling torch._scaled_mm, or weight-only quantization (fallback path)

Float8DynamicActivationFloat8WeightConfig is reused for the above, with its kernel option controlling which kernel users will use.

Test Plan:

python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/test_float8_rowwise_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2463, branch: jerryzh168/stack/9
1 parent cc359e6 commit 897ec7e
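
For orientation, a minimal usage sketch assembled from the summary above and the test changes below. It assumes PerRow is importable from torchao.quantization.granularity and that the payload/scale attribute names match the tests; the kernel option mentioned in the summary is omitted here because its exact signature is not shown in this commit.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    quantize_,
)
from torchao.quantization.granularity import PerRow

# Toy layer; shapes chosen to match the tests below.
model = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")

# Per-row activation + per-row weight; the same config drives both the
# fbgemm (torch.ops.fbgemm.f8f8bf16_rowwise) and aten (torch._scaled_mm) paths.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# The weight is now a Float8Tensor exposing its payload and scale as plain attributes.
assert isinstance(model.weight, Float8Tensor)
assert model.weight.float8_data.shape == (32, 64)
assert model.weight.scale.shape == (32, 1)  # one scale per output row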

7 files changed: +763, −137 lines


test/dtypes/test_affine_quantized_float.py

Lines changed: 38 additions & 83 deletions
@@ -25,10 +25,11 @@
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
 
-from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
+from torchao.dtypes.floatx.float8_layout import preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -236,12 +237,8 @@ def test_serialization(self, mode: str):
             new_layer = getattr(new_model, layer_name)
 
             # Compare weights
-            if mode == "weight-only":
-                original_weight = original_layer.weight.tensor_impl.float8_data.to(
-                    torch.float32
-                )
-                new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32)
-            else:
+            if mode == "static":
+                # note: we haven't migrated static quant to the new API
                 original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
                     torch.float32
                 )
@@ -250,6 +247,9 @@ def test_serialization(self, mode: str):
                         torch.float32
                     )
                 )
+            else:
+                original_weight = original_layer.weight.float8_data.to(torch.float32)
+                new_weight = new_layer.weight.float8_data.to(torch.float32)
 
             assert torch.allclose(original_weight, new_weight), (
                 f"Weights do not match for {layer_name}"
@@ -324,19 +324,15 @@ def test_mm_float8dq_per_row(
 
         quant_weight = test_linear.weight
 
-        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
-        weight_impl = quant_weight.original_weight_tensor.tensor_impl
-
-        self.assertTrue(hasattr(weight_impl, "float8_data"))
-        self.assertTrue(hasattr(weight_impl, "scale"))
-        self.assertFalse(weight_impl.transposed)
+        self.assertTrue(hasattr(quant_weight, "float8_data"))
+        self.assertTrue(hasattr(quant_weight, "scale"))
 
         # Verify scale shape for row-wise quantization
         expected_scale_shape = (out_features, 1)
-        actual_scale_shape = weight_impl.scale.shape
+        actual_scale_shape = quant_weight.scale.shape
         self.assertEqual(actual_scale_shape, expected_scale_shape)
 
-        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+        self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))
 
         input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
 
@@ -357,7 +353,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
 
         device = "cuda"
@@ -387,7 +383,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_dequantize_affine_float8_scale_broadcasting(self):
+    def test__dequantize_affine_float8_scale_broadcasting(self):
        """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
@@ -431,24 +427,24 @@ def test_float8_tensor_slicing_basic(self, granularity):
             model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         )
 
-        weight_impl = model.weight.original_weight_tensor.tensor_impl
+        weight = model.weight
 
         # Test dimension 0 slicing (rows)
-        sliced_0 = weight_impl[10:20]
+        sliced_0 = weight[10:20]
         self.assertEqual(sliced_0.shape, (10, 64))
 
         # Test dimension 1 slicing (columns)
-        sliced_1 = weight_impl[:, 20:40]
+        sliced_1 = weight[:, 20:40]
         self.assertEqual(sliced_1.shape, (32, 20))
 
         # Test combined slicing
-        sliced_both = weight_impl[5:15, 10:30]
+        sliced_both = weight[5:15, 10:30]
         self.assertEqual(sliced_both.shape, (10, 20))
 
         # Verify the sliced tensors are still Float8 tensors
-        self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))
+        self.assertTrue(isinstance(sliced_0, Float8Tensor))
+        self.assertTrue(isinstance(sliced_1, Float8Tensor))
+        self.assertTrue(isinstance(sliced_both, Float8Tensor))
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -466,16 +462,15 @@ def test_float8_tensor_slicing_per_tensor(self):
         )
 
         original_weight = model.weight
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale
+        original_scale = original_weight.scale
 
         # Test slicing
         sliced_weight = original_weight[10:20, 20:40]
-        sliced_impl = sliced_weight.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_weight.scale
 
         # For per-tensor quantization, scale should be identical
-        self.assertTrue(torch.equal(original_scale, sliced_impl.scale))
-        self.assertEqual(sliced_impl.scale.numel(), 1)
+        self.assertTrue(torch.equal(original_scale, sliced_scale))
+        self.assertEqual(sliced_scale.numel(), 1)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -497,27 +492,26 @@ def test_float8_tensor_slicing_per_row(self):
         )
 
         original_weight = model.weight  # Shape: (32, 64)
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale  # Shape: (32, 1)
+        original_scale = model.weight.scale  # Shape: (32, 1)
 
         # Test row slicing (dimension 0)
         sliced_rows = original_weight[10:20]  # Shape: (10, 64)
-        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_rows.scale
 
         # Scale should be sliced to match the rows
         expected_scale_shape = (10, 1)
-        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+        self.assertEqual(sliced_scale.shape, expected_scale_shape)
 
         # Verify the scale values are correct (should be subset of original)
-        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+        self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))
 
         # Test column slicing (dimension 1) - scale should not change for per-row
         sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
-        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+        sliced_cols_scale = sliced_cols.scale
 
         # Scale shape should remain the same since we're not changing rows
-        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
-        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+        self.assertEqual(sliced_cols_scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_scale, original_scale))
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -552,11 +546,11 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
     )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_functional_correctness(self, granularity):
         """Test that sliced tensors produce correct results in computations"""
         device = "cuda"
@@ -579,15 +573,16 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_weight_slice = quant_model.weight[0:16, 0:32]
 
         # Verify that the sliced weights maintain Float8 properties
-        self.assertTrue(hasattr(quant_weight_slice, "original_weight_tensor"))
-        sliced_impl = quant_weight_slice.original_weight_tensor.tensor_impl
-        self.assertTrue(isinstance(sliced_impl, Float8AQTTensorImpl))
+        self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
+        self.assertTrue(hasattr(quant_weight_slice, "scale"))
+        sliced_impl = quant_weight_slice
+        self.assertTrue(isinstance(sliced_impl, Float8Tensor))
 
         # Verify sliced weight shapes
         self.assertEqual(sliced_impl.float8_data.shape, (16, 32))
 
         # Get original quantized weight implementation for scale comparison
-        original_quant_impl = quant_model.weight.original_weight_tensor.tensor_impl
+        original_quant_impl = quant_model.weight
 
         # Verify scale properties based on granularity
         if isinstance(granularity, PerTensor):
@@ -604,7 +599,7 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
             )
 
         # Verify that sliced quantized data matches the correct slice from original
-        original_float8_data_slice = original_quant_impl.float8_data[0:16, 0:32]
+        original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]
         self.assertTrue(
             torch.equal(sliced_impl.float8_data, original_float8_data_slice)
         )
@@ -675,46 +670,6 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
-
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
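The recurring change across these tests is the accessor path: what previously required unwrapping through weight.original_weight_tensor.tensor_impl is now read directly off the Float8Tensor, and slicing carries the per-row scale along with the rows. A small before/after sketch mirroring the assertions above (same assumptions as the sketch near the top; illustrative, not an API reference):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    quantize_,
)
from torchao.quantization.granularity import PerRow

model = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
weight = model.weight

# Old (AffineQuantizedTensor) path, removed by this commit:
#   weight.original_weight_tensor.tensor_impl.float8_data
#   weight.original_weight_tensor.tensor_impl.scale
# New Float8Tensor path: payload and scale are direct attributes.
assert isinstance(weight, Float8Tensor)

# Row slicing keeps the subclass and slices the per-row scale with the rows;
# column slicing leaves the (out_features, 1) scale untouched.
sliced_rows = weight[10:20]
assert isinstance(sliced_rows, Float8Tensor)
assert torch.equal(sliced_rows.scale, weight.scale[10:20])

sliced_cols = weight[:, 20:40]
assert sliced_cols.scale.shape == (32, 1)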