diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py
index 1e681d00f9..ea869a1c39 100644
--- a/test/dtypes/test_fbgemm_fp8.py
+++ b/test/dtypes/test_fbgemm_fp8.py
@@ -128,6 +128,8 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)
diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py
index cba9d81ae0..eb1f059775 100644
--- a/test/dtypes/test_fbgemm_int4.py
+++ b/test/dtypes/test_fbgemm_int4.py
@@ -39,7 +39,6 @@ def setUp(self):
             weight_dtype=torch.int4,
             output_dtype=torch.bfloat16,
             block_size=[1, 1, 128],
-            transpose_input=True,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
@@ -134,6 +133,8 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py
index b6c1d72acc..85f83bcb50 100644
--- a/torchao/dtypes/fbgemm_fp8_tensor.py
+++ b/torchao/dtypes/fbgemm_fp8_tensor.py
@@ -90,7 +90,6 @@ def from_float(
         cls,
         w: torch.Tensor,
         activation_scale_ub: Optional[float] = None,
-        transpose_input: bool = False,
     ):
         if activation_scale_ub is None:
             activation_scale_ub = 1200.0
@@ -100,12 +99,6 @@ def from_float(
             dtype=torch.float,
             device=w.device,
         )
-        if transpose_input:
-            if w.ndim == 3:
-                w = w.transpose(-1, -2)
-            else:
-                w = w.t()
-
         wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
         # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py
index 0c00ee1a81..385f70e3bb 100644
--- a/torchao/dtypes/fbgemm_int4_tensor.py
+++ b/torchao/dtypes/fbgemm_int4_tensor.py
@@ -93,7 +93,6 @@ def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
-        transpose_input: bool = False,
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
@@ -101,12 +100,6 @@ def from_float(
         if int4_row_quantize_zp is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

-        if transpose_input:
-            if w.ndim == 3:
-                w = w.transpose(-1, -2)
-            else:
-                w = w.t()
-
         group_size = block_size[-1]
         original_shape = w.shape
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 4e2cdb8843..7df6995955 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -2047,7 +2047,6 @@ class FbgemmConfig(AOBaseConfig):
     output_dtype: torch.dtype
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
-    transpose_input: bool = False
     preshuffle: bool = False
@@ -2074,7 +2073,6 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_int4(
             module.weight,
             config.block_size,
-            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
@@ -2087,7 +2085,6 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_fp8(
             module.weight,
             config.activation_scale_ub,
-            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
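
The diff removes the `transpose_input` flag from `FbgemmConfig` and the `from_float` constructors, so callers that quantize a bmm weight now transpose it themselves before calling `quantize_`, exactly as the updated tests do. Below is a minimal caller-side sketch of that new pattern. It is not part of the patch: the `BMMModule` class and the input shape are hypothetical stand-ins for the test's local `M` module (whose definition is not shown in this diff), it assumes `FbgemmConfig` is importable from `torchao.quantization.quant_api`, and running it requires CUDA plus fbgemm-gpu-genai, same as the tests.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.quant_api import FbgemmConfig


class BMMModule(torch.nn.Module):
    # Hypothetical stand-in for the test's local `M` module: holds a
    # (B, K, N) weight and applies torch.bmm in forward.
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, self.weight)


dtype, device = torch.bfloat16, "cuda"

bmm_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 1, 128],
    # transpose_input no longer exists; nothing to set here.
)

input = torch.randn(10, 32, 128, dtype=dtype, device=device)  # assumed shape
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = BMMModule(weight).eval()
original = m(input)

# Caller-side responsibility after this change: put the bmm weight into
# (B, N, K) layout before quantization instead of passing transpose_input=True.
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, bmm_config, filter_fn=lambda mod, fqn: True)
quantized = m(input)
```

The effect is that the layout decision lives with the caller rather than inside the quantization config, which is what the two test updates above exercise for both the int4 and fp8 paths.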