Add support for resharding and int4 preshuffle kernel #2387

Open · wants to merge 1 commit into main
154 changes: 134 additions & 20 deletions test/dtypes/test_affine_quantized_float.py
@@ -25,6 +25,7 @@
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils

from torchao.dtypes import FbgemmFp8Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
@@ -324,19 +325,15 @@ def test_mm_float8dq_per_row(

quant_weight = test_linear.weight

self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
weight_impl = quant_weight.original_weight_tensor.tensor_impl

self.assertTrue(hasattr(weight_impl, "float8_data"))
self.assertTrue(hasattr(weight_impl, "scale"))
self.assertFalse(weight_impl.transposed)
self.assertTrue(hasattr(quant_weight, "float8_data"))
self.assertTrue(hasattr(quant_weight, "scale"))

# Verify scale shape for row-wise quantization
expected_scale_shape = (out_features, 1)
actual_scale_shape = weight_impl.scale.shape
actual_scale_shape = quant_weight.scale.shape
self.assertEqual(actual_scale_shape, expected_scale_shape)

self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))

input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)

@@ -419,11 +416,11 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_tensor_slicing_basic(self, granularity):
def test_float8_tensor_slicing_basic_per_tensor(self):
"""Test basic slicing operations on Float8 tensors"""
device = "cuda"
dtype = torch.bfloat16
granularity = PerTensor()

# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
@@ -450,6 +447,41 @@ def test_float8_tensor_slicing_basic(self, granularity):
self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_float8_tensor_slicing_basic_per_row(self):
"""Test basic slicing operations on Float8 tensors"""
device = "cuda"
dtype = torch.bfloat16
granularity = PerRow()

# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
)

weight = model.weight

# Test dimension 0 slicing (rows)
sliced_0 = weight[10:20]
self.assertEqual(sliced_0.shape, (10, 64))

# Test dimension 1 slicing (columns)
sliced_1 = weight[:, 20:40]
self.assertEqual(sliced_1.shape, (32, 20))

# Test combined slicing
sliced_both = weight[5:15, 10:30]
self.assertEqual(sliced_both.shape, (10, 20))

# Verify the sliced tensors are still Float8 tensors
self.assertTrue(isinstance(sliced_0, FbgemmFp8Tensor))
self.assertTrue(isinstance(sliced_1, FbgemmFp8Tensor))
self.assertTrue(isinstance(sliced_both, FbgemmFp8Tensor))
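
The three isinstance checks above pin down the slicing invariant for per-row quantization: each output row owns one scale, so a row slice must carry its scale rows along, while a column slice leaves the scale untouched. A plain-torch sketch of that invariant, not part of this diff (the quantization recipe here is a simplifying assumption, independent of FbgemmFp8Tensor internals):

import torch

w = torch.randn(32, 64, dtype=torch.bfloat16)
# Per-row quantization: one scale per output row, shape (out_features, 1).
scale = w.float().abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
data = (w.float() / scale).to(torch.float8_e4m3fn)

# Row slice: data and scale are sliced together.
assert data[10:20].shape == (10, 64) and scale[10:20].shape == (10, 1)

# Column slice: rows are unchanged, so the scale stays (32, 1).
assert data[:, 20:40].shape == (32, 20) and scale.shape == (32, 1)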

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -497,27 +529,26 @@ def test_float8_tensor_slicing_per_row(self):
)

original_weight = model.weight # Shape: (32, 64)
original_impl = original_weight.original_weight_tensor.tensor_impl
original_scale = original_impl.scale # Shape: (32, 1)
original_scale = model.weight.scale # Shape: (32, 1)

# Test row slicing (dimension 0)
sliced_rows = original_weight[10:20] # Shape: (10, 64)
sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
sliced_scale = sliced_rows.scale

# Scale should be sliced to match the rows
expected_scale_shape = (10, 1)
self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
self.assertEqual(sliced_scale.shape, expected_scale_shape)

# Verify the scale values are correct (should be subset of original)
self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))

# Test column slicing (dimension 1) - scale should not change for per-row
sliced_cols = original_weight[:, 20:40] # Shape: (32, 20)
sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
sliced_cols_scale = sliced_cols.scale

# Scale shape should remain the same since we're not changing rows
self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
self.assertEqual(sliced_cols_scale.shape, (32, 1))
self.assertTrue(torch.equal(sliced_cols_scale, original_scale))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
@@ -552,15 +583,15 @@ def test_float8_tensor_slicing_edge_cases(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
is_sm_version(8, 9),
"TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
)
def test_float8_tensor_slicing_functional_correctness(self, granularity):
def test_float8_tensor_slicing_functional_correctness_per_tensor(self):
"""Test that sliced tensors produce correct results in computations"""
device = "cuda"
dtype = torch.bfloat16
granularity = PerTensor()

# Create reference and quantized models with dimensions that are multiples of 16
ref_model = (
@@ -630,6 +661,89 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
error = compute_error(ref_output, quant_output)
self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@unittest.skipIf(
is_sm_version(8, 9),
"TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
)
def test_float8_tensor_slicing_functional_correctness_per_row(self):
"""Test that sliced tensors produce correct results in computations"""
device = "cuda"
dtype = torch.bfloat16
granularity = PerRow()

# Create reference and quantized models with dimensions that are multiples of 16
ref_model = (
torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
) # 48 is divisible by 16
quant_model = copy.deepcopy(ref_model)
quantize_(
quant_model,
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
)

# Create input with batch size that works well with slicing
input_tensor = torch.randn(8, 64, device=device, dtype=dtype)

ref_weight_slice = ref_model.weight[0:16, 0:32]
quant_weight_slice = quant_model.weight[0:16, 0:32]

# Verify that the sliced weights maintain Float8 properties
self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
self.assertTrue(hasattr(quant_weight_slice, "scale"))
sliced_impl = quant_weight_slice
self.assertTrue(isinstance(sliced_impl, FbgemmFp8Tensor))

# Verify sliced weight shapes
self.assertEqual(sliced_impl.float8_data.shape, (16, 32))

# Get original quantized weight implementation for scale comparison
original_quant_impl = quant_model.weight

# Verify scale properties: granularity is PerRow in this test, so the
# scale should be sliced to match the selected rows (0:16)
expected_scale_shape = (16, 1)
self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
# Verify the scale values are the correct slice from the original
self.assertTrue(
torch.equal(sliced_impl.scale, original_quant_impl.scale[0:16])
)

# Verify that sliced quantized data matches the correct slice from original
original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]
self.assertTrue(
torch.equal(sliced_impl.float8_data, original_float8_data_slice)
)

# Verify that sliced weights can be converted back to float with correct values
sliced_float_weight = quant_weight_slice.to(dtype)
self.assertEqual(sliced_float_weight.shape, (16, 32))
self.assertEqual(sliced_float_weight.dtype, dtype)

input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight

# Compute with sliced weights
with torch.no_grad():
ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice)
quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice)

# Verify shapes
expected_shape = (8, 16) # batch_size x out_features_sliced
self.assertEqual(ref_output.shape, expected_shape)
self.assertEqual(quant_output.shape, expected_shape)

# Verify reasonable quantization error
error = compute_error(ref_output, quant_output)
self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
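
For context, compute_error reports an SQNR in decibels, so the threshold of 15 asserts that the reference signal is well above the quantization noise. A minimal sketch of the standard definition, which torchao's implementation may differ from in details:

import torch

def sqnr_db(ref: torch.Tensor, quant: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - quant||): higher means less quantization error.
    return 20 * torch.log10(ref.float().norm() / (ref - quant).float().norm())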

def test_preprocess_scale_3d_reshape(self):
"""Test that preprocess_scale correctly handles 3D scale tensors"""
device = "cpu" # Use CPU for basic functionality test
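The preprocess_scale test above is truncated, but judging from its name the checked behavior is that a batched row-wise scale, e.g. shape (B, S, 1), is flattened to (B*S, 1) so it lines up with a 2D-reshaped activation. A standalone sketch under that assumption, not part of this diff (the real preprocess_scale signature is not shown here):

import torch

# Hypothetical stand-in for preprocess_scale's 3D handling: collapse the
# leading batch dimensions of a row-wise scale to match a 2D-flattened input.
def preprocess_scale_sketch(scale: torch.Tensor) -> torch.Tensor:
    return scale.reshape(-1, 1)

scale_3d = torch.rand(2, 8, 1)  # (batch, seq, 1) row-wise scales
assert preprocess_scale_sketch(scale_3d).shape == (16, 1)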