Remove set fp32 math mode & increase tolerance #1860

Merged
5 changes: 3 additions & 2 deletions python/tutorials/10-experimental-block-pointer.py
@@ -327,7 +327,6 @@ def matmul(a, b, accum_dtype, res_dtype):
# Still, we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
-torch.xpu.set_fp32_math_mode(torch.xpu.utils.FP32MathMode.TF32)
Contributor:
This should work also when using upstream PyTorch. I do not think we should be fixing the tutorial.

Contributor Author:
@guangyey's point is to use this workaround until this feature is implemented upstream.

@guangyey (Aug 15, 2024):
Hi, @etiotto
set_fp32_math_mode is not yet upstreamed to stock PyTorch, and we need time to redesign this API to account for other backends. So I personally recommend that @ZzEeKkAa use this workaround until we complete this API.
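
As a sketch only (not part of this PR's diff), the call could also be guarded rather than removed, assuming torch.xpu.set_fp32_math_mode and torch.xpu.utils.FP32MathMode are provided by intel_extension_for_pytorch and absent from upstream PyTorch builds, as in the line removed above:

import torch

# Sketch: only lower FP32 matmul precision to TF32 when the IPEX-provided API exists.
# Upstream PyTorch does not expose torch.xpu.set_fp32_math_mode, so the guard skips it there.
if hasattr(torch, "xpu") and hasattr(torch.xpu, "set_fp32_math_mode"):
    torch.xpu.set_fp32_math_mode(torch.xpu.utils.FP32MathMode.TF32)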

Contributor:

This tutorial is not currently present upstream

@etiotto (Contributor, Aug 16, 2024):

> Hi, @etiotto set_fp32_math_mode is not yet upstreamed to stock PyTorch, and we need time to redesign this API to account for other backends. So I personally recommend that @ZzEeKkAa use this workaround until we complete this API.

OK. We can put this in to unblock the work of migrating to use PyTorch (instead of IPEX). @ZzEeKkAa please add a FIXME in the code and open an issue so that once we have support in PyTorch for set_fp32_math_mode we can go back and revert this change. Once that is done I will be able to approve the PR. Thanks.

Contributor:

Going by Julian's comment, we might also want to add a ticket to support true fp32 matmul. This will of course have the downside of not using DPAS, so it will be slow by default (not ideal, I know).
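
As a rough sketch of what "true fp32" looks like on the torch reference side, using only standard upstream PyTorch API (whether the XPU backend honors this setting, and what the Triton kernel side would need, is outside this sketch and is what the proposed ticket would cover):

import torch

# Request IEEE FP32 matmuls from torch: "highest" keeps full fp32 precision,
# while "high"/"medium" allow TF32-style lower-precision paths.
torch.set_float32_matmul_precision("highest")

a = torch.randn(512, 512, dtype=torch.float32)
b = torch.randn(512, 512, dtype=torch.float32)
ref = torch.matmul(a, b)  # full-precision fp32 reference, allowing tighter tolerances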

Contributor:

@ZzEeKkAa Please create the 2 issues as described above.

Contributor:

Ping. @ZzEeKkAa, are the 2 issues mentioned above open? Links?

Contributor Author:

@whitneywhtsang @etiotto I've just opened the issues and updated the PR with a FIXME comment:
#1956
#1957

for dtype, accum_dtype, res_dtype in [(torch.float16, torch.float16, torch.float16),
(torch.float16, torch.float32, torch.float16),
(torch.float16, torch.float32, torch.float32),
@@ -373,7 +372,9 @@ def matmul(a, b, accum_dtype, res_dtype):
# Note: the torch.matmul and Triton implementations use different
# algorithms so we need to adjust tolerance.
rtol = 1e-2 if dtype == torch.bfloat16 or accum_dtype in [torch.float16, torch.bfloat16] else 1e-3
-atol = 1e-2 if accum_dtype == torch.bfloat16 else 1e-3 if accum_dtype == torch.float16 else 1e-4
+# FIXME: Remove 1e-1 tolerance for fp32, once fp32 math mode is implemented in PyTorch:
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/1957
+atol = 1e-1 if dtype == torch.float32 else 1e-2 if accum_dtype == torch.bfloat16 else 1e-3 if accum_dtype == torch.float16 else 1e-4
FMarno marked this conversation as resolved.
if torch.allclose(triton_output, torch_output, atol=atol, rtol=rtol):
print("✅ Triton and Torch match")
else:
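
For readability, the tolerance choice above can be read as a small helper; this is only a sketch that mirrors the same logic and is not part of the diff:

import torch

def tolerances(dtype, accum_dtype):
    # Mirrors the tutorial's tolerance selection (sketch only).
    rtol = 1e-2 if dtype == torch.bfloat16 or accum_dtype in [torch.float16, torch.bfloat16] else 1e-3
    if dtype == torch.float32:
        atol = 1e-1  # relaxed until fp32 math mode is available upstream (issue #1957)
    elif accum_dtype == torch.bfloat16:
        atol = 1e-2
    elif accum_dtype == torch.float16:
        atol = 1e-3
    else:
        atol = 1e-4
    return atol, rtol

# Usage with the comparison in the tutorial:
#   atol, rtol = tolerances(dtype, accum_dtype)
#   torch.allclose(triton_output, torch_output, atol=atol, rtol=rtol)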