test rowwise fp32 #2431

Merged
merged 1 commit on Jul 7, 2025

9 changes: 5 additions & 4 deletions test/float8/test_base.py
@@ -34,9 +34,7 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
         x_shape,
+        linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
         if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
             )
             pytest.skip()
 
-        linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
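For context, a minimal, hedged sketch of the scenario the new linear_dtype parametrization covers: converting a float32 nn.Linear with a rowwise recipe and running forward plus backward. The recipe-name string and the config import path are assumptions based on the test file, and an SM90-class GPU is required.

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# linear_dtype used to be hard-coded to torch.bfloat16 in the test; it is now a parameter.
linear_dtype = torch.float32
x = torch.randn(2, 16, 16, device="cuda", dtype=linear_dtype)

m = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
config = Float8LinearConfig.from_recipe_name("rowwise")  # recipe name assumed
m = convert_to_float8_training(m, config=config)

y = m(x)
y.sum().backward()
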
9 changes: 9 additions & 0 deletions torchao/float8/float8_ops.py
@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+

vkuzo (Contributor) commented on Jun 24, 2025:

instead of adding a flag, TBH I think we can just enable this on-by-default, like this:

file issue in PyTorch core to add float32 output to scaled_mm

output_dtype_to_use = output_dtype
if is_rowwise_scaling:
    # work around torch._scaled_mm not having float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output_dtype_to_use = torch.bfloat16
output = torch._scaled_mm(..., output_dtype_to_use, ...)
...
if is_rowwise_scaling and output_dtype == torch.float32:
    # work around torch._scaled_mm not having float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output = output.to(orig_dtype)


The PR author (Contributor) replied:

makes sense, I'll change to enable by default and file an issue.

     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32

@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output


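To see the workaround in isolation, here is an illustrative, hedged sketch (not the torchao code): quantize with rowwise scales, ask torch._scaled_mm for a bfloat16 output, then cast the result back to the dtype the caller actually requested. The tensor names, scale construction, and shapes are assumptions; running it needs a float8-capable GPU (for example an H100) and a recent PyTorch.

import torch

M, K, N = 16, 32, 64
a_hp = torch.randn(M, K, device="cuda")
b_hp = torch.randn(K, N, device="cuda")

# Rowwise (axiswise) dequantization scales: one per row of A, one per column of B.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = a_hp.abs().amax(dim=1, keepdim=True) / fp8_max  # shape (M, 1)
b_scale = b_hp.abs().amax(dim=0, keepdim=True) / fp8_max  # shape (1, N)

a_fp8 = (a_hp / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b_hp / b_scale).to(torch.float8_e4m3fn)

requested_dtype = torch.float32
# torch._scaled_mm rejects float32 (and float16) output with rowwise scales,
# so compute in bfloat16 and cast back afterwards (the same pattern as the diff above).
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t().contiguous().t(),  # second operand must be column-major
    scale_a=a_scale,
    scale_b=b_scale,
    out_dtype=torch.bfloat16,
)
out = out.to(requested_dtype)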