diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 879551fc0a..079ba5d05f 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -25,6 +25,7 @@ from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils +from torchao.dtypes import FbgemmFp8Tensor from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale from torchao.float8.float8_utils import compute_error from torchao.quantization import ( @@ -324,19 +325,15 @@ def test_mm_float8dq_per_row( quant_weight = test_linear.weight - self.assertTrue(hasattr(quant_weight, "original_weight_tensor")) - weight_impl = quant_weight.original_weight_tensor.tensor_impl - - self.assertTrue(hasattr(weight_impl, "float8_data")) - self.assertTrue(hasattr(weight_impl, "scale")) - self.assertFalse(weight_impl.transposed) + self.assertTrue(hasattr(quant_weight, "float8_data")) + self.assertTrue(hasattr(quant_weight, "scale")) # Verify scale shape for row-wise quantization expected_scale_shape = (out_features, 1) - actual_scale_shape = weight_impl.scale.shape + actual_scale_shape = quant_weight.scale.shape self.assertEqual(actual_scale_shape, expected_scale_shape) - self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features)) + self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features)) input_tensor = torch.randn(*input_shape, device=device, dtype=dtype) @@ -419,11 +416,11 @@ def test_dequantize_affine_float8_scale_broadcasting(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - def test_float8_tensor_slicing_basic(self, granularity): + def test_float8_tensor_slicing_basic_per_tensor(self): """Test basic slicing operations on Float8 tensors""" device = "cuda" dtype = torch.bfloat16 + granularity = PerTensor() # Create and quantize a model model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) @@ -450,6 +447,41 @@ def test_float8_tensor_slicing_basic(self, granularity): self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl)) self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl)) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + def test_float8_tensor_slicing_basic_per_row(self): + """Test basic slicing operations on Float8 tensors""" + device = "cuda" + dtype = torch.bfloat16 + granularity = PerRow() + + # Create and quantize a model + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + quantize_( + model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + ) + + weight = model.weight + + # Test dimension 0 slicing (rows) + sliced_0 = weight[10:20] + self.assertEqual(sliced_0.shape, (10, 64)) + + # Test dimension 1 slicing (columns) + sliced_1 = weight[:, 20:40] + self.assertEqual(sliced_1.shape, (32, 20)) + + # Test combined slicing + sliced_both = weight[5:15, 10:30] + self.assertEqual(sliced_both.shape, (10, 20)) + + # Verify the sliced tensors are still Float8 tensors + self.assertTrue(isinstance(sliced_0, FbgemmFp8Tensor)) + self.assertTrue(isinstance(sliced_1, FbgemmFp8Tensor)) + self.assertTrue(isinstance(sliced_both, FbgemmFp8Tensor)) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( not is_sm_at_least_89(), 
"Requires GPU with compute capability >= 8.9" @@ -497,27 +529,26 @@ def test_float8_tensor_slicing_per_row(self): ) original_weight = model.weight # Shape: (32, 64) - original_impl = original_weight.original_weight_tensor.tensor_impl - original_scale = original_impl.scale # Shape: (32, 1) + original_scale = model.weight.scale # Shape: (32, 1) # Test row slicing (dimension 0) sliced_rows = original_weight[10:20] # Shape: (10, 64) - sliced_impl = sliced_rows.original_weight_tensor.tensor_impl + sliced_scale = sliced_rows.scale # Scale should be sliced to match the rows expected_scale_shape = (10, 1) - self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) + self.assertEqual(sliced_scale.shape, expected_scale_shape) # Verify the scale values are correct (should be subset of original) - self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20])) + self.assertTrue(torch.equal(sliced_scale, original_scale[10:20])) # Test column slicing (dimension 1) - scale should not change for per-row sliced_cols = original_weight[:, 20:40] # Shape: (32, 20) - sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl + sliced_cols_scale = sliced_cols.scale # Scale shape should remain the same since we're not changing rows - self.assertEqual(sliced_cols_impl.scale.shape, (32, 1)) - self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale)) + self.assertEqual(sliced_cols_scale.shape, (32, 1)) + self.assertTrue(torch.equal(sliced_cols_scale, original_scale)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( @@ -552,15 +583,15 @@ def test_float8_tensor_slicing_edge_cases(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @unittest.skipIf( is_sm_version(8, 9), "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15", ) - def test_float8_tensor_slicing_functional_correctness(self, granularity): + def test_float8_tensor_slicing_functional_correctness_per_tensor(self): """Test that sliced tensors produce correct results in computations""" device = "cuda" dtype = torch.bfloat16 + granularity = PerTensor() # Create reference and quantized models with dimensions that are multiples of 16 ref_model = ( @@ -630,6 +661,89 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity): error = compute_error(ref_output, quant_output) self.assertGreater(error, 15, f"Quantization SQNR too low: {error}") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @unittest.skipIf( + is_sm_version(8, 9), + "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15", + ) + def test_float8_tensor_slicing_functional_correctness_per_row(self): + """Test that sliced tensors produce correct results in computations""" + device = "cuda" + dtype = torch.bfloat16 + granularity = PerRow() + + # Create reference and quantized models with dimensions that are multiples of 16 + ref_model = ( + torch.nn.Linear(64, 48, bias=False).to(device).to(dtype) + ) # 48 is divisible by 16 + quant_model = copy.deepcopy(ref_model) + quantize_( + quant_model, + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + ) + + # Create input with batch size that works well with slicing + input_tensor = torch.randn(8, 64, device=device, dtype=dtype) + + 
ref_weight_slice = ref_model.weight[0:16, 0:32] + quant_weight_slice = quant_model.weight[0:16, 0:32] + + # Verify that the sliced weights maintain Float8 properties + self.assertTrue(hasattr(quant_weight_slice, "float8_data")) + self.assertTrue(hasattr(quant_weight_slice, "scale")) + sliced_impl = quant_weight_slice + self.assertTrue(isinstance(sliced_impl, FbgemmFp8Tensor)) + + # Verify sliced weight shapes + self.assertEqual(sliced_impl.float8_data.shape, (16, 32)) + + # Get original quantized weight implementation for scale comparison + original_quant_impl = quant_model.weight + + # Verify scale properties based on granularity + if isinstance(granularity, PerTensor): + # Per-tensor: scale should be identical to original (scalar) + self.assertEqual(sliced_impl.scale.numel(), 1) + self.assertTrue(torch.equal(sliced_impl.scale, original_quant_impl.scale)) + else: # PerRow + # Per-row: scale should be sliced to match the selected rows (0:16) + expected_scale_shape = (16, 1) + self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) + # Verify the scale values are the correct slice from the original + self.assertTrue( + torch.equal(sliced_impl.scale, original_quant_impl.scale[0:16]) + ) + + # Verify that sliced quantized data matches the correct slice from original + original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32] + self.assertTrue( + torch.equal(sliced_impl.float8_data, original_float8_data_slice) + ) + + # Verify that sliced weights can be converted back to float with correct values + sliced_float_weight = quant_weight_slice.to(dtype) + self.assertEqual(sliced_float_weight.shape, (16, 32)) + self.assertEqual(sliced_float_weight.dtype, dtype) + + input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight + + # Compute with sliced weights + with torch.no_grad(): + ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice) + quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice) + + # Verify shapes + expected_shape = (8, 16) # batch_size x out_features_sliced + self.assertEqual(ref_output.shape, expected_shape) + self.assertEqual(quant_output.shape, expected_shape) + + # Verify reasonable quantization error + error = compute_error(ref_output, quant_output) + self.assertGreater(error, 15, f"Quantization SQNR too low: {error}") + def test_preprocess_scale_3d_reshape(self): """Test that preprocess_scale correctly handles 3D scale tensors""" device = "cpu" # Use CPU for basic functionality test diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py index 1e681d00f9..f5a8030b7c 100644 --- a/test/dtypes/test_fbgemm_fp8.py +++ b/test/dtypes/test_fbgemm_fp8.py @@ -9,12 +9,14 @@ import torch from torch.testing._internal.common_utils import ( TestCase, + instantiate_parametrized_tests, + parametrize, run_tests, ) -from torchao.float8.config import e4m3_dtype from torchao.quantization import ( - FbgemmConfig, + Float8DynamicActivationFloat8WeightConfig, + PerRow, quantize_, ) from torchao.quantization.utils import compute_error @@ -23,36 +25,35 @@ is_sm_at_least_90, ) +FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), kernel="fbgemm" +) +ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), kernel="aten" +) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") class TestFbgemmFp8Tensor(TestCase): def 
setUp(self): - self.config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - ) - self.bmm_config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - transpose_input=True, - ) self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] - def test_linear(self): + @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG]) + def test_linear(self, config): dtype = torch.bfloat16 device = "cuda" input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) - quantize_(linear, self.config) + quantize_(linear, config) quantized = linear(input) - self.assertTrue(compute_error(original, quantized) > 20) + sqnr = compute_error(original, quantized) + self.assertTrue(sqnr > 20, f"sqnr: {sqnr}") - def test_slice(self): + @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG]) + def test_slice(self, config): dtype = torch.bfloat16 device = "cuda" dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) @@ -65,7 +66,7 @@ def test_slice(self): dummy.weight.narrow(1, 0, 128), requires_grad=False ) - quantize_(dummy, self.config) + quantize_(dummy, config) weight1 = dummy.weight.narrow(0, 0, 64) weight2 = dummy.weight.narrow(1, 0, 128) self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64)) @@ -81,20 +82,23 @@ def test_slice(self): res_ref = dummy1(input) dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) res = dummy(input) - assert compute_error(res, res_ref) > 25 + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 25, f"sqnr: {sqnr}") input = torch.randn(2, 128, dtype=dtype, device=device) res_ref = dummy2(input) dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) res = dummy(input) - assert compute_error(res, res_ref) > 15 + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 15, f"sqnr: {sqnr}") - def test_slice_and_copy_(self): + @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG]) + def test_slice_and_copy_(self, config): l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) l.weight = torch.nn.Parameter( torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") ) - quantize_(l, self.config) + quantize_(l, config) param = l.weight param_data = param.data param_data = param_data.narrow(0, 0, 512) @@ -104,7 +108,7 @@ def test_slice_and_copy_(self): # dummy_l has random input (shouldn't be 0) dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - quantize_(dummy_l, self.config) + quantize_(dummy_l, config) quantized = dummy_l.weight quantized = quantized.narrow(0, 0, 512) @@ -113,7 +117,8 @@ def test_slice_and_copy_(self): # making sure param.data is updated assert param.data.float8_data[0][0] != orig_value - def test_bmm(self): + @parametrize("config", [FBGEMM_CONFIG]) + def test_bmm(self, config): class M(torch.nn.Module): def __init__(self, weight): super().__init__() @@ -128,24 +133,80 @@ def forward(self, x): weight = torch.randn(10, 128, 256, dtype=dtype, device=device) m = M(weight).eval() original = m(input) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) + # we need to transpose the weight first for bmm + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, config, filter_fn=lambda x, fqn: True) quantized = m(input) self.assertTrue(compute_error(original, quantized) > 20) - def test_to_device(self): + @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG]) + def 
test_to_device(self, config): for device in self.GPU_DEVICES: linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device=device) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device) + @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG]) + def test_cat(self, config): + dtype = torch.bfloat16 + device = "cuda" + # weight: (256, 128) + linear1 = torch.nn.Linear(128, 256, dtype=dtype) + # weight: (256, 128) + linear2 = torch.nn.Linear(128, 256, dtype=dtype) + + cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device) + + dummy1.weight = torch.nn.Parameter(cat_weight1) + quantize_(dummy1, config) + + quantize_(linear1, config) + quantize_(linear2, config) + + cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + self.assertTrue(cat_qweight1.shape, (512, 128)) + self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data) + self.assertEqual(dummy1.weight.scale, cat_qweight1.scale) + + # concat with dim == 1 is not really correct and will be fixed later + # when we support distributed checkpointing + cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + self.assertTrue(cat_qweight2.shape, (256, 256)) + ref_float8_data = torch.cat( + [linear1.weight.float8_data, linear2.weight.float8_data], dim=1 + ) + ref_scale = linear1.weight.scale + self.assertEqual(cat_qweight2.float8_data, ref_float8_data) + self.assertEqual(cat_qweight2.scale, ref_scale) + + @parametrize("config", [FBGEMM_CONFIG]) + def test_transpose(self, config): + dtype = torch.bfloat16 + device = "cuda" + # weight: (256, 128) + linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + quantize_(linear1, config) + linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous()) + linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device)) + self.assertTrue(linear1.weight.shape, (128, 256)) + + input = torch.randn(32, 256, dtype=dtype, device=device) + # make sure it runs + res = linear1(input) + self.assertTrue(res.shape, (32, 128)) + + +instantiate_parametrized_tests(TestFbgemmFp8Tensor) + if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py index cba9d81ae0..12598fdeab 100644 --- a/test/dtypes/test_fbgemm_int4.py +++ b/test/dtypes/test_fbgemm_int4.py @@ -39,7 +39,6 @@ def setUp(self): weight_dtype=torch.int4, output_dtype=torch.bfloat16, block_size=[1, 1, 128], - transpose_input=True, ) self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] @@ -134,6 +133,7 @@ def forward(self, x): weight = torch.randn(10, 128, 256, dtype=dtype, device=device) m = M(weight).eval() original = m(input) + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) quantized = m(input) self.assertTrue(compute_error(original, quantized) > 18) @@ -152,6 +152,53 @@ def test_to_device(self): quantize_(linear, self.config) linear.to(device) + def test_cat(self): + dtype = torch.bfloat16 + device = "cuda" + # weight: (256, 128) + linear1 = torch.nn.Linear(128, 256, dtype=dtype) + # weight: (256, 128) + linear2 = torch.nn.Linear(128, 256, dtype=dtype) + + cat_weight1 
= torch.cat([linear1.weight, linear2.weight], dim=0) + cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device) + dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + + dummy1.weight = torch.nn.Parameter(cat_weight1) + dummy2.weight = torch.nn.Parameter(cat_weight2) + quantize_(dummy1, self.config) + quantize_(dummy2, self.config) + + quantize_(linear1, self.config) + quantize_(linear2, self.config) + + cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + self.assertTrue(cat_qweight1.shape, (512, 128)) + self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight) + self.assertEqual(dummy1.weight.scale, cat_qweight1.scale) + self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point) + + cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + self.assertTrue(cat_qweight2.shape, (256, 256)) + self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight) + self.assertEqual(dummy2.weight.scale, cat_qweight2.scale) + self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point) + + def test_transpose(self): + # weight: (256, 128) + linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + quantize_(linear1, self.config) + linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous()) + # transpose again to return to the original state + linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous()) + self.assertTrue(linear1.weight.shape, (256, 128)) + + input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") + # make sure it runs + res = linear1(input) + self.assertTrue(res.shape, (32, 256)) + if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_int4_groupwise_preshuffle.py b/test/dtypes/test_int4_groupwise_preshuffle.py new file mode 100644 index 0000000000..7aeadaafa2 --- /dev/null +++ b/test/dtypes/test_int4_groupwise_preshuffle.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
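+# Tests for the new Int4GroupwisePreshuffleTensor, created via FbgemmConfig(..., preshuffle=True):
+# covers linear, slicing, slice-and-copy_, bmm and device movement.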
+ +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.quantization import ( + FbgemmConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + is_sm_at_least_90, +) + + +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +class TestInt4GroupwisePreshuffleTensor(TestCase): + def setUp(self): + self.config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 128], + preshuffle=True, + ) + self.bmm_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 1, 128], + preshuffle=True, + ) + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + def test_linear(self): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, self.config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + @unittest.skip("WIP: this doesn't work yet") + def test_slice(self): + dtype = torch.bfloat16 + device = "cuda" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, self.config) + weight1 = dummy.weight.narrow(0, 0, 64) + weight2 = dummy.weight.narrow(1, 0, 128) + self.assertEqual( + weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64) + ) + self.assertEqual(weight1.group_scale, dummy.weight.group_scale.narrow(1, 0, 64)) + self.assertEqual( + weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64) + ) + self.assertEqual(weight2.group_scale, dummy.weight.group_scale.narrow(0, 0, 1)) + + # check for sliced weight, before and after float8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + assert sqnr > 20, f"Got: {sqnr}" + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + assert sqnr > 15, f"Got: {sqnr}" + + def test_slice_and_copy_(self): + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, self.config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + assert ( + param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr() + ) + assert param.data.group_scale.data_ptr() == param_data.group_scale.data_ptr() + assert param.data.row_scale.data_ptr() == param_data.row_scale.data_ptr() + orig_value = param.data.packed_weight[0][0].item() + + # dummy_l has 
random input (shouldn't be 0) + dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + quantize_(dummy_l, self.config) + quantized = dummy_l.weight + quantized = quantized.narrow(0, 0, 512) + + param_data.copy_(quantized) + + # making sure param.data is updated + assert param.data.packed_weight[0][0] != orig_value + + def test_bmm(self): + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(10, 32, 128, dtype=dtype, device=device) + weight = torch.randn(10, 128, 256, dtype=dtype, device=device) + m = M(weight).eval() + original = m(input) + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) + quantized = m(input) + self.assertTrue(compute_error(original, quantized) > 18) + + def test_to_device(self): + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, self.config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, self.config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, self.config) + linear.to(device) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/core/config.py b/torchao/core/config.py index 3451b90c59..03bcee2e3b 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -12,6 +12,14 @@ import torch +__all__ = [ + "AOBaseConfig", + "VersionMismatchError", + "config_to_dict", + "config_from_dict", + "ALLOWED_AO_MODULES", +] + class AOBaseConfig(abc.ABC): """ diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 581c3e4ecb..36fb3c3f72 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,6 +14,10 @@ CutlassSemiSparseLayout, Float8Layout, ) +from .int4_groupwise_preshuffle_tensor import ( + Int4GroupwisePreshuffleTensor, + to_int4_groupwise_preshuffle, +) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( BlockSparseLayout, @@ -67,4 +71,6 @@ "FbgemmInt4Tensor", "to_fbgemm_fp8", "FbgemmFp8Tensor", + "Int4GroupwisePreshuffleTensor", + "to_int4_groupwise_preshuffle", ] diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py index b6c1d72acc..72523d8a43 100644 --- a/torchao/dtypes/fbgemm_fp8_tensor.py +++ b/torchao/dtypes/fbgemm_fp8_tensor.py @@ -10,6 +10,21 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.floatx.float8_layout import ( + preprocess_scale, +) +from torchao.dtypes.utils import get_out_shape +from torchao.float8.inference import ( + Float8MMConfig, + addmm_float8_unwrapped_inference, + preprocess_data, +) +from torchao.quantization.granularity import PerRow +from torchao.quantization.observer import get_block_size +from torchao.quantization.quant_primitives import ( + _choose_qparams_affine_float8, + _quantize_affine_float8, +) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, @@ -26,13 +41,34 @@ class FbgemmFp8Tensor(TorchAOBaseTensor): """ + Float8 Rowwise Quantized (weight) Tensor, with float8 rowwise dynamic quantization for activation. 
TODO: needs padding for cutlass kernels + + Tensor Attributes: + float8_data: float8 raw data + scale: the rowwise scale for float8 Tensor + activation_scale_ub: upper bound for activation scale, used during dynamic quantization for activation + + Non-Tensor Attributes: + rowwise_dim (int): the dimension for rowwise quantization, initially it's -1, but might change when we + transpose the Tensor + dtype: Original Tensor dtype """ tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"] - tensor_attributes = ["dtype"] + tensor_attributes = ["rowwise_dim", "mm_config", "kernel", "dtype"] + _SUPPORTED_KERNELS = ["fbgemm", "aten"] - def __new__(cls, float8_data, scale, activation_scale_ub, dtype): + def __new__( + cls, + float8_data, + scale, + activation_scale_ub, + rowwise_dim, + mm_config, + kernel, + dtype, + ): shape = float8_data.shape kwargs = {} kwargs["device"] = float8_data.device @@ -40,10 +76,22 @@ def __new__(cls, float8_data, scale, activation_scale_ub, dtype): kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, float8_data, scale, activation_scale_ub, dtype): + def __init__( + self, + float8_data, + scale, + activation_scale_ub, + rowwise_dim, + mm_config, + kernel, + dtype, + ): self.float8_data = float8_data self.scale = scale self.activation_scale_ub = activation_scale_ub + self.rowwise_dim = rowwise_dim % self.float8_data.ndim + self.mm_config = mm_config + self.kernel = kernel def __tensor_flatten__(self): return self.tensor_data_attrs, [ @@ -68,12 +116,13 @@ def _apply_fn_to_data(self, fn): def __repr__(self): return ( f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, " - f"activation_scale_ub={self.activation_scale_ub}, " - f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + f"activation_scale_ub={self.activation_scale_ub}, rowwise_dim={self.rowwise_dim}, " + f"mm_config={self.mm_config}, kernel={self.kernel}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): - return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}" + return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, rowwise_dim={self.rowwise_dim}, mm_config={self.mm_config}, kernel={self.kernel}, device={self.device}" def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -82,6 +131,76 @@ def to(self, *args, **kwargs): self.float8_data.to(device), self.scale.to(device), self.activation_scale_ub.to(device), + self.rowwise_dim, + self.mm_config, + self.kernel, + self.dtype, + ) + + def _transpose_and_reshape(self): + """This is added for resharding support, since the resharding logic for the model we are + working with only support 2D + + High level goal is to match the shape of the original unquantized Tensor and reshape + it to 2D since resharding logic only supports 2D Tensor + + * transpose(1, 2) since we did a transpose initially to quantize the weight + * reshape to 2D + """ + assert len(self.shape) == 3, ( + f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}" + ) + dim0, dim1, dim2 = self.shape + # because we first transpose the weight before quantization, we'll recover the original shape + # by swapping dim1 and dim2 + original_shape = (dim0, dim2, dim1) + # we must save this as 2D in the state dict, since loading code expects 2D weights 
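+        # e.g. a float8_data of shape (E, N, K) becomes (E, K, N) after the transpose and (E * K, N) after the reshape;
+        # the rowwise scale goes through the same transpose and reshape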
+ new_shape = (-1, original_shape[-1]) + float8_data = self.float8_data + float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous() + scale = self.scale.transpose(1, 2).reshape(*new_shape).contiguous() + if self.rowwise_dim in [0, 2]: + rowwise_dim = 0 + else: + rowwise_dim = 1 + + return self.__class__( + float8_data, + scale, + self.activation_scale_ub, + rowwise_dim, + self.mm_config, + self.kernel, + self.dtype, + ) + + def _unflatten(self, num_experts): + """This is added for resharding support, since the resharding logic for the model we are + working with only support 2D + + This is called after resharding logic, and it reverses the reshape to 2D in `_transpose_and_reshape` + and gives a 3D tensor with `num_experts` as the first dimension + """ + float8_data = self.float8_data + scale = self.scale + float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0) + scale = scale.unflatten(0, (num_experts, -1)).squeeze(dim=0) + if self.rowwise_dim == 0: + rowwise_dim = 1 + else: + rowwise_dim = 2 + + for d in range(len(float8_data.shape)): + if float8_data.shape[d] != scale.shape[d] and scale.shape[d] == 1: + rowwise_dim = d + + return self.__class__( + float8_data, + scale, + self.activation_scale_ub, + rowwise_dim, + self.mm_config, + self.kernel, self.dtype, ) @@ -89,9 +208,17 @@ def to(self, *args, **kwargs): def from_float( cls, w: torch.Tensor, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, activation_scale_ub: Optional[float] = None, - transpose_input: bool = False, + mm_config: Optional[Float8MMConfig] = None, + kernel: str = "fbgemm", ): + assert kernel in cls._SUPPORTED_KERNELS + + assert input_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] + assert weight_dtype == input_dtype + if activation_scale_ub is None: activation_scale_ub = 1200.0 @@ -100,22 +227,41 @@ def from_float( dtype=torch.float, device=w.device, ) - if transpose_input: - if w.ndim == 3: - w = w.transpose(-1, -2) - else: - w = w.t() - - wq, w_scale = torch.ops.triton.quantize_fp8_row(w) - # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) - dtype = w.dtype - del w - return FbgemmFp8Tensor( - wq, - w_scale, - activation_scale_ub=activation_scale_ub, - dtype=dtype, - ) + + if kernel == "fbgemm": + wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) + # add a last dimension for per row quantization to align the rank of + # w_scale and wq + w_scale = w_scale.unsqueeze(-1).contiguous() + dtype = w.dtype + del w + return FbgemmFp8Tensor( + wq, + w_scale, + activation_scale_ub=activation_scale_ub, + rowwise_dim=wq.ndim - 1, + mm_config=mm_config, + kernel=kernel, + dtype=dtype, + ) + + else: + block_size = get_block_size(w.shape, PerRow()) + w_scale = _choose_qparams_affine_float8( + w, float8_dtype=weight_dtype, block_size=block_size + ) + wq = _quantize_affine_float8(w, w_scale, weight_dtype) + dtype = w.dtype + del w + return FbgemmFp8Tensor( + wq, + w_scale, + activation_scale_ub=activation_scale_ub, + rowwise_dim=wq.ndim - 1, + mm_config=mm_config, + kernel=kernel, + dtype=dtype, + ) implements = FbgemmFp8Tensor.implements @@ -131,27 +277,59 @@ def _(func, types, args, kwargs): orig_act_size = input_tensor.size() orig_out_features = weight_tensor.shape[-2] - # not used - num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - input_tensor, num_tokens, weight_tensor.activation_scale_ub - ) + if weight_tensor.kernel == "fbgemm": + # not used + num_tokens = 
torch.empty([input_tensor.size(0)], device=input_tensor.device) + a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row( + input_tensor, num_tokens, weight_tensor.activation_scale_ub + ) - a_data = xq - b_data = weight_tensor.float8_data + b_data = weight_tensor.float8_data + b_scale = weight_tensor.scale.squeeze(-1) - res = torch.ops.fbgemm.f8f8bf16_rowwise( - a_data, - b_data, - x_scale, - weight_tensor.scale, - use_fast_accum=True, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - res = res + bias + res = torch.ops.fbgemm.f8f8bf16_rowwise( + a_data, + b_data, + a_scale, + b_scale, + use_fast_accum=True, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + return res + else: + scaled_mm_config = weight_tensor.mm_config + assert scaled_mm_config is not None + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + block_size = get_block_size(input_tensor.shape, PerRow()) + weight_dtype = weight_tensor.float8_data.dtype + # Note: we assume input dtype is the same as weight dtype + input_scale = _choose_qparams_affine_float8( + input_tensor, float8_dtype=weight_dtype, block_size=block_size + ) + inpt_data = _quantize_affine_float8(input_tensor, input_scale, weight_dtype) - return res + # Extract tensor data and scales + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + w_data = weight_tensor.float8_data + w_scale = weight_tensor.scale + w_scale = w_scale.transpose(-1, -2) + + input_scale = preprocess_scale(input_scale, input_tensor.shape) + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) @implements(torch.bmm) @@ -161,21 +339,27 @@ def _(func, types, args, kwargs): args[1], ) orig_act_size = input_tensor.size() + assert weight_tensor.kernel == "fbgemm", ( + f"Only fbgemm kernel support bmm right now, got {weight_tensor.kernel}" + ) + # not used num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row( input_tensor, num_tokens, weight_tensor.activation_scale_ub ) - a_data = xq b_data = weight_tensor.float8_data + b_scale = weight_tensor.scale.squeeze(-1) + assert b_data.is_contiguous(), "weight for bmm must be contiguous" + orig_out_features = b_data.shape[-2] res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( a_data, b_data, - x_scale, - weight_tensor.scale, + a_scale, + b_scale, ) res = res.reshape(*orig_act_size[:-1], orig_out_features) return res @@ -203,6 +387,7 @@ def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool: and self.float8_data.shape == src.float8_data.shape and self.scale.shape == src.scale.shape and self.activation_scale_ub.shape == src.activation_scale_ub.shape + and self.rowwise_dim == src.rowwise_dim and self.dtype == src.dtype ) @@ -250,12 +435,12 @@ def _(func, types, args, kwargs): self.float8_data, dim, start, end, step ).contiguous() - if dim == 0: + if dim != self.rowwise_dim: # scale has dimension (N,) where N is the dim 0 of `self` # so we do the same slice on scale for dimension 0 sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step) else: - # since scale is per row, slicing along the dim == 1 dimension does + # since scale is per row, slicing 
along the rowwise dimension does # not change the scale sliced_scale = self.scale @@ -264,11 +449,88 @@ def _(func, types, args, kwargs): args, kwargs, FbgemmFp8Tensor( - sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype + sliced_data, + sliced_scale, + self.activation_scale_ub, + self.rowwise_dim, + self.mm_config, + self.kernel, + dtype=self.dtype, ), ) +@implements(aten.cat.default) +def _(func, types, args, kwargs): + """Concatenate multiple float8 quantized tensors + (scale and float8_data has the same rank) + If the concatenation dimension is not the same as rowwise_dim, then we can just concatenate the + float8_data and scale directly + If the concatention dimension is the same as rowwise_dim, theoretically we should either + (1) check that scales from all tensors are equal and use the first scale + (2) dequantize and requantize + but for now we just use the first scale directly, which might have slight implication on accuaracy + we can improve upon this a bit later + """ + + tensors, dim = fill_defaults(args, 2, [[], 0]) + tensor_0 = tensors[0] + dim = dim % tensor_0.ndim + + # assert dim != tensor_0.rowwise_dim, f"Doesn't support concatenation over rowwise dimension: {dim=} {tensor_0.float8_data.shape=}, {tensor_0.rowwise_dim=} {tensor_0.scale.shape=}" + + for i in range(1, len(tensors)): + assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub + assert tensor_0.rowwise_dim == tensors[i].rowwise_dim + + float8_datas = [t.float8_data for t in tensors] + scales = [t.scale for t in tensors] + + cat_float8_data = aten.cat.default(float8_datas, dim=dim) + if dim != tensor_0.rowwise_dim: + cat_scale = aten.cat.default(scales, dim=dim) + else: + # TODO: this is not exactly correct, we'll need to + # figure out how to do this properly in the future + cat_scale = scales[0] + + new = tensor_0.__class__( + cat_float8_data, + cat_scale, + tensor_0.activation_scale_ub, + tensor_0.rowwise_dim, + tensor_0.mm_config, + tensor_0.kernel, + tensor_0.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + float8_data = self.float8_data.transpose(dim0, dim1).contiguous() + scale = self.scale.transpose(dim0, dim1).contiguous() + + if self.rowwise_dim == dim0: + rowwise_dim = dim1 + elif self.rowwise_dim == dim1: + rowwise_dim = dim0 + + new = self.__class__( + float8_data, + scale, + self.activation_scale_ub, + rowwise_dim, + self.mm_config, + self.kernel, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + to_fbgemm_fp8 = FbgemmFp8Tensor.from_float diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py index 0c00ee1a81..9a8a9bc668 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -32,20 +32,43 @@ class FbgemmInt4Tensor(TorchAOBaseTensor): + """ + Groupwise int4 weight only quantization + + Tensor Attributes: + packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed + scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor + dtype is the same as the original Tensor dtype + zero_point: Same size as the scale + dtype is the same as the original Tensor dtype + + Non-Tensor Attributes: + group_size: the group size for groupwise quantization + shape_multiplier: is the 
multipler from packed_weight to the real weight, since + we pack the weight for int4, for example, when we pack the last dimension for + a 2D tensor, the shape_multiplier will be [1, 2] + shape: shape of the original Tensor + """ + tensor_data_attrs = ["packed_weight", "scale", "zero_point"] - tensor_attributes = ["group_size", "shape"] + tensor_attributes = ["group_size", "shape_multiplier", "shape"] - def __new__(cls, packed_weight, scale, zero_point, group_size, shape): + def __new__( + cls, packed_weight, scale, zero_point, group_size, shape_multiplier, shape + ): kwargs = {} kwargs["device"] = packed_weight.device kwargs["dtype"] = scale.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, packed_weight, scale, zero_point, group_size, shape): + def __init__( + self, packed_weight, scale, zero_point, group_size, shape_multiplier, shape + ): self.packed_weight = packed_weight self.scale = scale self.zero_point = zero_point + self.shape_multiplier = shape_multiplier self.group_size = group_size def __tensor_flatten__(self): @@ -71,7 +94,8 @@ def _apply_fn_to_data(self, fn): def __repr__(self): return ( f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " - f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): @@ -85,15 +109,89 @@ def to(self, *args, **kwargs): self.scale.to(device), self.zero_point.to(device), self.group_size, + self.shape_multiplier, self.shape, ) + def _transpose_and_reshape(self): + """This is added for resharding support, since the resharding logic for the model we are + working with only support 2D + + High level goal is to match the shape of the original unquantized Tensor and reshape + it to 2D since resharding logic only supports 2D Tensor + + * transpose(1, 2) since we did a transpose initially to quantize the weight + * reshape to 2D + """ + assert len(self.shape) == 3, ( + f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}" + ) + dim0, dim1, dim2 = self.shape + shape_multiplier = self.shape_multiplier.copy() + # because we first transpose the weight before quantization, we'll recover the original shape + # by swapping dim1 and dim2 + assert shape_multiplier[-1] == 2, ( + "Expecting original weight to be packed in the last dimension" + ) + original_shape = (dim0, dim2, dim1) + # we must save this as 2D in the state dict, since loading code expects 2D weights + new_shape = (-1, original_shape[-1]) + packed_weight = self.packed_weight + packed_weight = packed_weight.transpose(1, 2).reshape(*new_shape).contiguous() + # expecting the packed dimension to be swapped to 0 + shape_multiplier = [2, 1] + + tensor_shape = list(packed_weight.shape) + for i in range(len(shape_multiplier)): + tensor_shape[i] *= shape_multiplier[i] + tensor_shape = tuple(tensor_shape) + return self.__class__( + packed_weight, + self.scale, + self.zero_point, + self.group_size, + shape_multiplier, + tensor_shape, + ) + + def _unflatten(self, num_experts): + """This is added for resharding support, since the resharding logic for the model we are + working with only support 2D + + This is called after resharding logic, and it reverses the reshape to 2D in `_transpose_and_reshape` + and gives a 3D tensor with 
`num_experts` as the first dimension + """ + packed_weight = self.packed_weight + shape_multiplier = self.shape_multiplier + dim0, dim1 = self.shape + packed_weight = packed_weight.unflatten(0, (num_experts, -1)).squeeze(dim=0) + if shape_multiplier == [2, 1]: + shape_multiplier = [1, 2, 1] + elif shape_multiplier == [1, 2]: + shape_multiplier = [1, 1, 2] + else: + raise NotImplementedError( + f"Unexpected shape multiplier: {shape_multiplier}" + ) + + tensor_shape = list(packed_weight.shape) + for i in range(len(shape_multiplier)): + tensor_shape[i] *= shape_multiplier[i] + tensor_shape = tuple(tensor_shape) + return self.__class__( + packed_weight, + self.scale, + self.zero_point, + self.group_size, + shape_multiplier, + tensor_shape, + ) + @classmethod def from_float( cls, w: torch.Tensor, block_size: List[int], - transpose_input: bool = False, ): assert len(block_size) == w.ndim, ( f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" @@ -101,12 +199,6 @@ def from_float( if int4_row_quantize_zp is None: raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - if transpose_input: - if w.ndim == 3: - w = w.transpose(-1, -2) - else: - w = w.t() - group_size = block_size[-1] original_shape = w.shape @@ -124,12 +216,16 @@ def from_float( scale = scale.to(w.dtype) zero_point = zero_point.to(w.dtype) + shape_multiplier = [1] * wq.ndim + shape_multiplier[-1] = 2 + del w return FbgemmInt4Tensor( packed_weight=wq, scale=scale, zero_point=zero_point, group_size=group_size, + shape_multiplier=shape_multiplier, shape=original_shape, ) @@ -144,16 +240,21 @@ def _(func, types, args, kwargs): args[1], args[2] if len(args) > 2 else None, ) - orig_act_size = input_tensor.size() + orig_input_size = input_tensor.size() orig_out_features = weight_tensor.shape[-2] + assert weight_tensor.shape_multiplier[-1] == 2, ( + "Expecting the last dimension of weight to be the packed dimension" + ) + input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]) res = torch.ops.fbgemm.bf16i4bf16_rowwise( input_tensor, weight_tensor.packed_weight.contiguous(), weight_tensor.scale, weight_tensor.zero_point, ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) + + res = res.reshape(*orig_input_size[:-1], orig_out_features) if bias is not None: res = res + bias return res @@ -165,8 +266,9 @@ def _(func, types, args, kwargs): args[0], args[1], ) - orig_act_size = input_tensor.size() + orig_input_size = input_tensor.size() orig_out_features = weight_tensor.shape[-2] + assert weight_tensor.shape_multiplier[-1] == 2 res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( input_tensor, @@ -174,7 +276,7 @@ def _(func, types, args, kwargs): weight_tensor.scale, weight_tensor.zero_point, ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) + res = res.reshape(*orig_input_size[:-1], orig_out_features) return res @@ -201,6 +303,7 @@ def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool: and self.scale.shape == src.scale.shape and self.zero_point.shape == src.zero_point.shape and self.group_size == src.group_size + and self.shape_multiplier == src.shape_multiplier ) @@ -222,7 +325,7 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): """Only supports slicing for dim == 1 and dim == 2 packed_weight has dimension: (N, K/2) - scale and zero_point has dimension: (K/groups, N) + scale and zero_point has dimension: (K/group_size, N) dim, start, end, step are args that's referring to the original tensor shape which is (N, 
K), and we need to map that to the transformed weight shape of packed_weight, @@ -289,7 +392,80 @@ def _(func, types, args, kwargs): packed_shape0, packed_shape1 = packed_weight.shape new_shape = (packed_shape0, packed_shape1 * 2) new = self.__class__( - packed_weight, scale, zero_point, group_size=self.group_size, shape=new_shape + packed_weight, + scale, + zero_point, + group_size=self.group_size, + shape_multiplier=self.shape_multiplier, + shape=new_shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.cat.default) +def _(func, types, args, kwargs): + tensors, dim = fill_defaults(args, 2, [[], 0]) + tensor_0 = tensors[0] + if dim < 0: + dim = dim + tensor_0.ndim + + for i in range(1, len(tensors)): + assert tensor_0.packed_weight.ndim == tensors[i].packed_weight.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.zero_point.ndim == tensors[i].zero_point.ndim + assert tensor_0.group_size == tensors[i].group_size + assert tensor_0.shape_multiplier == tensors[i].shape_multiplier + + packed_weight = [t.packed_weight for t in tensors] + scale = [t.scale for t in tensors] + zero_point = [t.zero_point for t in tensors] + + # with group wise quantization, dimension of scale, packed_weight and + # origianl shape will be the same, so original dim argument applies + # to both packed_weight and scale + cat_packed_weight = aten.cat.default(packed_weight, dim) + if cat_packed_weight.ndim == 2: + sz_dim = 1 - dim + else: + sz_dim = dim + cat_scale = aten.cat.default(scale, sz_dim) + cat_zero_point = aten.cat.default(zero_point, sz_dim) + new_shape = list(cat_packed_weight.shape) + for i in range(len(tensor_0.shape_multiplier)): + new_shape[i] *= tensor_0.shape_multiplier[i] + new_shape = tuple(new_shape) + new = tensor_0.__class__( + cat_packed_weight, + cat_scale, + cat_zero_point, + group_size=tensor_0.group_size, + shape_multiplier=tensor_0.shape_multiplier, + shape=new_shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + packed_weight = self.packed_weight.transpose(dim0, dim1).contiguous() + shape_multiplier = self.shape_multiplier.copy() + shape_multiplier[dim0], shape_multiplier[dim1] = ( + shape_multiplier[dim1], + shape_multiplier[dim0], + ) + + tensor_shape = list(packed_weight.shape) + for i in range(len(shape_multiplier)): + tensor_shape[i] *= shape_multiplier[i] + tensor_shape = tuple(tensor_shape) + new = self.__class__( + packed_weight, + self.scale, + self.zero_point, + self.group_size, + shape_multiplier, + tensor_shape, ) return return_and_correct_aliasing(func, args, kwargs, new) diff --git a/torchao/dtypes/int4_groupwise_preshuffle_tensor.py b/torchao/dtypes/int4_groupwise_preshuffle_tensor.py new file mode 100644 index 0000000000..a972f94fe9 --- /dev/null +++ b/torchao/dtypes/int4_groupwise_preshuffle_tensor.py @@ -0,0 +1,406 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
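+# Int4GroupwisePreshuffleTensor stores fbgemm preshuffled int4 packed weights (packed_weight) together with
+# groupwise scales (group_scale) and per-row scales (row_scale), plus group_size and shape_multiplier metadata.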
+ + +import importlib.util +from typing import List + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, + fill_defaults, +) + +__all__ = [ + "to_int4_groupwise_preshuffle", + "Int4GroupwisePreshuffleTensor", +] + +aten = torch.ops.aten + + +if importlib.util.find_spec("fbgemm_gpu") is None: + quantize_int4_preshuffle = None +else: + from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle + + +class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor): + """ + Args: + shape_multiplier: is the multipler from packed_weight to the real weight, since + we pack the weight for int4, for example, when we pack the last dimension for + a 2D tensor, the shape_multiplier will be [1, 2] + """ + + tensor_data_attrs = ["packed_weight", "group_scale", "row_scale"] + tensor_attributes = ["group_size", "shape_multiplier", "shape"] + + def __new__( + cls, packed_weight, group_scale, row_scale, group_size, shape_multiplier, shape + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["dtype"] = group_scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, packed_weight, group_scale, row_scale, group_size, shape_multiplier, shape + ): + self.packed_weight = packed_weight + self.group_scale = group_scale + self.row_scale = row_scale + self.shape_multiplier = shape_multiplier + self.group_size = group_size + + def __tensor_flatten__(self): + return self.tensor_data_attrs, [ + getattr(self, attr) for attr in self.tensor_attributes + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_data_attrs], + *tensor_attributes, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], + *[getattr(self, attr) for attr in self.tensor_attributes], + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " + f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, " + f"requires_grad={self.requires_grad})" + ) + + def _quantization_type(self): + return f"shape={self.shape}, group_size={self.group_size}, device={self.device}" + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.packed_weight.to(device), + self.group_scale.to(device), + self.row_scale.to(device), + self.group_size, + self.shape_multiplier, + self.shape, + ) + + @classmethod + def from_float( + cls, + w: torch.Tensor, + block_size: List[int], + ): + assert len(block_size) == w.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" + ) + if quantize_int4_preshuffle is None: + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + + group_size = block_size[-1] + original_shape = w.shape + + if w.ndim >= 3: + wq, scales = zip( + *[quantize_int4_preshuffle(i.cuda(), dtype="bf16") for i in w] + ) + wq = torch.stack(wq, dim=0) + group_scale, row_scale = zip(*scales) + row_scale = torch.stack(row_scale, dim=0) + group_scale = torch.stack(group_scale, dim=0) + else: + wq, (group_scale, row_scale) = quantize_int4_preshuffle( + 
w.cuda(), dtype="bf16" + ) + + shape_multiplier = [1] * wq.ndim + shape_multiplier[-1] = 2 + + del w + return Int4GroupwisePreshuffleTensor( + packed_weight=wq, + group_scale=group_scale, + row_scale=row_scale, + group_size=group_size, + shape_multiplier=shape_multiplier, + shape=original_shape, + ) + + +implements = Int4GroupwisePreshuffleTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + orig_input_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + wq = weight_tensor.packed_weight + group_scale = weight_tensor.group_scale + row_scale = weight_tensor.row_scale + + if input_tensor.dim() == 3: + B, M, _ = input_tensor.shape + _, N, _ = wq.shape + res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16) + for i in range(B): + res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled( + input_tensor[i], wq[i], group_scale[i], row_scale[i] + ) + else: + # Otherwise run gemm normally. + res = torch.ops.fbgemm.bf16i4bf16_shuffled( + input_tensor, wq, group_scale, row_scale + ) + + res = res.reshape(*orig_input_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + return res + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + orig_input_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + assert weight_tensor.shape_multiplier[-1] == 2 + + wq = weight_tensor.packed_weight + group_scale = weight_tensor.group_scale + row_scale = weight_tensor.row_scale + B, M, _ = input_tensor.shape + _, N, _ = wq.shape + res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16) + for i in range(B): + res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled( + input_tensor[i], wq[i], group_scale[i], row_scale[i] + ) + res = res.reshape(*orig_input_size[:-1], orig_out_features) + return res + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +def _same_metadata( + self: "Int4GroupwisePreshuffleTensor", src: "Int4GroupwisePreshuffleTensor" +) -> bool: + return ( + isinstance(self, Int4GroupwisePreshuffleTensor) + and isinstance(src, Int4GroupwisePreshuffleTensor) + and self.shape == src.shape + and self.packed_weight.shape == src.packed_weight.shape + and self.group_scale.shape == src.group_scale.shape + and self.row_scale.shape == src.row_scale.shape + and self.group_size == src.group_size + and self.shape_multiplier == src.shape_multiplier + ) + + +@implements(aten.copy_.default) +def _(func, types, args, kwargs): + self = args[0] + src = args[1] + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Only supports slicing for dim == 1 and dim == 2 + packed_weight has dimension: (N, K/2) + group_scale and row_scale has dimension: (K/groups, N) + + dim, start, end, 
+@implements([aten.detach.default, aten.alias.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+def _same_metadata(
+    self: "Int4GroupwisePreshuffleTensor", src: "Int4GroupwisePreshuffleTensor"
+) -> bool:
+    return (
+        isinstance(self, Int4GroupwisePreshuffleTensor)
+        and isinstance(src, Int4GroupwisePreshuffleTensor)
+        and self.shape == src.shape
+        and self.packed_weight.shape == src.packed_weight.shape
+        and self.group_scale.shape == src.group_scale.shape
+        and self.row_scale.shape == src.row_scale.shape
+        and self.group_size == src.group_size
+        and self.shape_multiplier == src.shape_multiplier
+    )
+
+
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    """Only supports slicing for dim == 0 and dim == 1
+    packed_weight has dimension: (N, K/2)
+    group_scale and row_scale have dimension: (K / group_size, N)
+
+    dim, start, end, step are args that refer to the original tensor shape,
+    which is (N, K); we need to map them to the transformed shapes of
+    packed_weight, group_scale and row_scale
+
+    when dim == 0: we slice packed_weight along dim 0 and group_scale/row_scale
+    along dim 1, adjusting the start and end indexes based on the ratio between
+    the original shape and the shape of packed_weight and group_scale/row_scale
+
+    when dim == 1: we slice packed_weight along dim 1 and group_scale/row_scale
+    along dim 0, with the same ratio-based adjustment
+
+    Note that we need to call slice on packed_weight, group_scale and row_scale
+    directly because slice is an operation that needs to preserve aliasing, see
+    `test_slice_and_copy_` in `test_fbgemm_int4` for details
+    """
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+
+    assert self.packed_weight.ndim == 2, (
+        f"Expected packed weight to have dim 2, got {self.packed_weight.ndim}"
+    )
+    N, K_by_2 = self.packed_weight.shape
+    sz_dim0, sz_dim1 = self.group_scale.shape
+
+    data_len = self.shape[dim]
+
+    if dim == 0:
+        pw_len = N
+        sz_len = sz_dim1
+    else:
+        pw_len = K_by_2
+        sz_len = sz_dim0
+
+    sz_dim = 1 - dim
+    if pw_len == 0 or sz_len == 0:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            self.__class__(
+                self.packed_weight,
+                self.group_scale,
+                self.row_scale,
+                group_size=self.group_size,
+                shape_multiplier=self.shape_multiplier,
+                shape=self.shape,
+            ),
+        )
+
+    pw_ratio = data_len / pw_len
+    start_pw = int(start / pw_ratio)
+    end_pw = int(end / pw_ratio)
+
+    sz_ratio = data_len / sz_len
+    start_sz = int(start / sz_ratio)
+    end_sz = int(end / sz_ratio)
+
+    packed_weight = aten.slice.Tensor(self.packed_weight, dim, start_pw, end_pw, step)
+    group_scale = aten.slice.Tensor(self.group_scale, sz_dim, start_sz, end_sz, step)
+    row_scale = aten.slice.Tensor(self.row_scale, sz_dim, start_sz, end_sz, step)
+    packed_shape0, packed_shape1 = packed_weight.shape
+    new_shape = (packed_shape0, packed_shape1 * 2)
+    new = self.__class__(
+        packed_weight,
+        group_scale,
+        row_scale,
+        group_size=self.group_size,
+        shape_multiplier=self.shape_multiplier,
+        shape=new_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
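As a worked example of the start/end rescaling in the slice override, the concrete sizes below are hypothetical and chosen only to make the arithmetic visible.

# Logical weight is (N, K); packed_weight is (N, K / 2); group_scale is (K / group_size, N).
N, K, group_size = 256, 128, 64
pw_len = K // 2           # packed_weight extent along the sliced dim for dim == 1 -> 64
sz_len = K // group_size  # group_scale extent along its corresponding dim         -> 2

# weight[:, 32:96] on the logical (N, K) tensor (dim == 1):
start, end, data_len = 32, 96, K
pw_ratio = data_len / pw_len                                     # 2.0
start_pw, end_pw = int(start / pw_ratio), int(end / pw_ratio)    # 16, 48
sz_ratio = data_len / sz_len                                     # 64.0
start_sz, end_sz = int(start / sz_ratio), int(end / sz_ratio)    # 0, 1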
+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = dim + tensor_0.ndim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.packed_weight.ndim == tensors[i].packed_weight.ndim
+        assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
+        assert tensor_0.row_scale.ndim == tensors[i].row_scale.ndim
+        assert tensor_0.group_size == tensors[i].group_size
+        assert tensor_0.shape_multiplier == tensors[i].shape_multiplier
+
+    packed_weight = [t.packed_weight for t in tensors]
+    group_scale = [t.group_scale for t in tensors]
+    row_scale = [t.row_scale for t in tensors]
+
+    # with groupwise quantization, group_scale, packed_weight and the original
+    # shape all have the same number of dimensions, so the original dim argument
+    # applies to both packed_weight and group_scale
+    cat_packed_weight = aten.cat.default(packed_weight, dim)
+    if cat_packed_weight.ndim == 2:
+        sz_dim = 1 - dim
+    else:
+        sz_dim = dim
+
+    cat_group_scale = aten.cat.default(group_scale, sz_dim)
+    cat_row_scale = aten.cat.default(row_scale, sz_dim)
+    new_shape = list(cat_packed_weight.shape)
+    for i in range(len(tensor_0.shape_multiplier)):
+        new_shape[i] *= tensor_0.shape_multiplier[i]
+    new_shape = tuple(new_shape)
+    new = tensor_0.__class__(
+        cat_packed_weight,
+        cat_group_scale,
+        cat_row_scale,
+        group_size=tensor_0.group_size,
+        shape_multiplier=tensor_0.shape_multiplier,
+        shape=new_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    packed_weight = self.packed_weight.transpose(dim0, dim1).contiguous()
+    shape_multiplier = self.shape_multiplier.copy()
+    shape_multiplier[dim0], shape_multiplier[dim1] = (
+        shape_multiplier[dim1],
+        shape_multiplier[dim0],
+    )
+
+    tensor_shape = list(packed_weight.shape)
+    for i in range(len(shape_multiplier)):
+        tensor_shape[i] *= shape_multiplier[i]
+    tensor_shape = tuple(tensor_shape)
+    new = self.__class__(
+        packed_weight,
+        self.group_scale,
+        self.row_scale,
+        self.group_size,
+        shape_multiplier,
+        tensor_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+to_int4_groupwise_preshuffle = Int4GroupwisePreshuffleTensor.from_float
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with Int4GroupwisePreshuffleTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([Int4GroupwisePreshuffleTensor])
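To make the shape_multiplier bookkeeping used by the cat and transpose overrides concrete, here is a small illustration; the packed shapes are hypothetical.

# The packed int4 payload stores two 4-bit values per byte in its last dimension,
# so the multiplier for that dimension is 2 and the logical shape is packed * multiplier.
packed_shape = [8, 4096, 64]
shape_multiplier = [1, 1, 2]
logical_shape = [d * m for d, m in zip(packed_shape, shape_multiplier)]      # [8, 4096, 128]

# After transpose(1, 2) the packed dims swap, and so do their multipliers.
packed_shape_t = [8, 64, 4096]
shape_multiplier[1], shape_multiplier[2] = shape_multiplier[2], shape_multiplier[1]
logical_shape_t = [d * m for d, m in zip(packed_shape_t, shape_multiplier)]  # [8, 128, 4096]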
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 939f68e59a..f8ebaa7981 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -53,7 +53,6 @@ def short_str(self):
 class Float8TypeConfig:
     """
     Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
     Currently, ROCm supports 1. fnuz variants in MI300. 2. OCP F8 variants in MI350/Navi4.
     """
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 7b40f388ed..24e3d72c5e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -48,6 +48,7 @@
     to_affine_quantized_intx,
     to_fbgemm_fp8,
     to_fbgemm_int4,
+    to_int4_groupwise_preshuffle,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -1545,8 +1546,12 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
+    activation_scale_ub: Optional[float] = None
+    kernel: str = "aten"
     set_inductor_config: bool = True
 
+    _SUPPORTED_KERNELS = ["aten", "fbgemm"]
+
     def __post_init__(self):
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
@@ -1555,6 +1560,7 @@ def __post_init__(self):
             self.granularity
         )
         self.granularity = [activation_granularity, weight_granularity]
+        assert self.kernel in self._SUPPORTED_KERNELS
 
 
 # for bc
@@ -1566,6 +1572,8 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mm_config = config.mm_config
+    kernel = config.kernel
+    activation_scale_ub = config.activation_scale_ub
 
     # Ensure works on device
     _check_hardware_support(granularity)
@@ -1583,23 +1591,37 @@
     block_size = get_block_size(weight.shape[-2:], weight_granularity)
     if weight.dim() == 3:
         block_size = tuple([1] + list(block_size))
-    quantized_weight = to_affine_quantized_floatx(
-        input_float=weight,
-        block_size=block_size,
-        target_dtype=weight_dtype,
-        scale_dtype=torch.float32,
-        _layout=Float8Layout(mm_config=mm_config),
-    )
-
-    input_quant_func = _input_activation_quant_func_fp8
-    input_quant_kwargs = {
-        "activation_granularity": activation_granularity,
-        "activation_dtype": activation_dtype,
-    }
+    if isinstance(activation_granularity, PerRow) and isinstance(
+        weight_granularity, PerRow
+    ):
+        quantized_weight = to_fbgemm_fp8(
+            weight,
+            activation_dtype,
+            weight_dtype,
+            activation_scale_ub,
+            mm_config,
+            kernel,
+        )
+    else:
+        # Note: kernel is not used for the non-per-row quantization case yet
+        quantized_weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=weight_dtype,
+            scale_dtype=torch.float32,
+            _layout=Float8Layout(mm_config=mm_config),
+        )
 
-    quantized_weight = to_linear_activation_quantized(
-        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-    )
+        input_quant_func = _input_activation_quant_func_fp8
+        input_quant_kwargs = {
+            "activation_granularity": activation_granularity,
+            "activation_dtype": activation_dtype,
+        }
+
+        quantized_weight = to_linear_activation_quantized(
+            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+        )
 
     return quantized_weight
@@ -1998,7 +2020,9 @@ class FbgemmConfig(AOBaseConfig):
     output_dtype: torch.dtype
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
-    transpose_input: bool = False
+    preshuffle: bool = False
+    mm_config = Float8MMConfig(use_fast_accum=True)
+    kernel: str = "fbgemm"
 
 
 @register_quantize_module_handler(FbgemmConfig)
@@ -2018,28 +2042,36 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         (e4m3_dtype, e4m3_dtype, torch.bfloat16),
     }
 
+    _FP8_DTYPES = {
+        # CUDA
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16),
+        # AMD
+        (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, torch.bfloat16),
+    }
+
     if (
         (config.input_dtype == torch.bfloat16)
         and (config.weight_dtype == torch.int4)
        and (config.output_dtype == torch.bfloat16)
     ):
-        weight = to_fbgemm_int4(
-            module.weight,
-            config.block_size,
-            config.transpose_input,
-        )
+        if config.preshuffle:
+            weight = to_int4_groupwise_preshuffle(module.weight, config.block_size)
+        else:
+            weight = to_fbgemm_int4(
+                module.weight,
+                config.block_size,
+            )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
-    elif (
-        (config.input_dtype == e4m3_dtype)
-        and (config.weight_dtype == e4m3_dtype)
-        and (config.output_dtype == torch.bfloat16)
-    ):
+    elif (config.input_dtype, config.weight_dtype, config.output_dtype) in _FP8_DTYPES:
         weight = to_fbgemm_fp8(
             module.weight,
+            config.input_dtype,
+            config.weight_dtype,
             config.activation_scale_ub,
-            config.transpose_input,
+            config.mm_config,
+            config.kernel,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
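A usage sketch of the new per-row fp8 path added above. It assumes a CUDA device with compute capability >= 8.9 and the fbgemm GenAI kernels available; with both activation and weight granularity set to PerRow and kernel="fbgemm", the quantized weight follows the to_fbgemm_fp8 branch rather than the affine-quantized one.

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(64, 32, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), kernel="fbgemm"),
)
out = model(torch.randn(4, 64, device="cuda", dtype=torch.bfloat16))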