diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 15099dc2c1..df86c6f04e 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -34,9 +34,7 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
         x_shape,
+        linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
         if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
             )
             pytest.skip()
 
-        linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index 4071d83e4f..7e5432c6c5 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
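
Note (not part of the patch): the float8_ops.py hunks implement a dtype round-trip: when rowwise scaling is active and the caller requests float16/float32 output, the matmul runs with bfloat16 output and the result is cast back afterwards. The following is a minimal, self-contained sketch of that pattern; rowwise_scaled_mm_bf16_only is a hypothetical stand-in for a kernel (such as torch._scaled_mm with rowwise scales) that can only emit bfloat16, not the real torchao code path.

import torch

def rowwise_scaled_mm_bf16_only(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in kernel: always emits bfloat16, mimicking the rowwise
    # torch._scaled_mm path before pytorch/pytorch#156771 is resolved.
    return (a.float() @ b.float()).to(torch.bfloat16)

def mm_with_requested_dtype(a, b, output_dtype: torch.dtype) -> torch.Tensor:
    # Same shape as the patch: remember the requested dtype, let the kernel
    # produce bfloat16, then cast back to what the caller asked for.
    out = rowwise_scaled_mm_bf16_only(a, b)
    if output_dtype in (torch.float16, torch.float32):
        out = out.to(output_dtype)
    return out

a = torch.randn(16, 16)
b = torch.randn(16, 32)
print(mm_with_requested_dtype(a, b, torch.float32).dtype)  # torch.float32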