
FlashAttention tutorial requires relaxed verification in advanced path (perf_attn) #2098

Open
victor-eds opened this issue Sep 3, 2024 · 10 comments

@victor-eds

victor-eds commented Sep 3, 2024

Comparing the Triton vs. XeTLA FlashAttention output using atol=1e-2, rtol=0, as in upstream, leads to size 1 32 16384 64 failing verification. A more relaxed atol=1e-1 passes, but this might be a bit too permissive considering the values will be less than 1 anyway (FlashAttention applies a softmax).

In order to reproduce, add the following code to the forward function, right before the return:

torch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False).to(torch.float32)
torch.testing.assert_close(o, torch_output, atol=1e-2, rtol=0)
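For reference, the more relaxed check mentioned above, which does pass but is arguably too lax for softmax outputs below 1, is the same assertion with the looser tolerance:

torch.testing.assert_close(o, torch_output, atol=1e-1, rtol=0)  # passes, but likely too permissive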
victor-eds changed the title from "FlashAttention tutorial requires relaxed verification in advanced path" to "FlashAttention tutorial requires relaxed verification in advanced path (perf_attn)" on Sep 3, 2024
vlad-penkin added this to the 4.0 [Performance] Core milestone on Sep 5, 2024
@Dewei-Wang-sh

There is some IR that differs from the PoC; we need to check whether they are equivalent.

@quintinwang5

Actually, this is not a problem on the Triton side.
In the attention computation, output = PV, where P is a softmax over the last dimension (in the naive 2D case, a row), so every row of P sums to 1. If we set V to all ones using torch.ones, the output should therefore be all ones (or very close to 1). For 1 32 16384 64, Triton is right, but torch is not (a sketch of this check appears at the end of this comment).

triton_output tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]]],
       device='xpu:0')
===========================================================
torch tensor([[[[0.8726, 0.8726, 0.8726,  ..., 0.8726, 0.8726, 0.8726],
          [0.8726, 0.8726, 0.8726,  ..., 0.8726, 0.8726, 0.8726],
          [0.8706, 0.8706, 0.8706,  ..., 0.8706, 0.8706, 0.8706],
          ...,
          [0.8716, 0.8716, 0.8716,  ..., 0.8716, 0.8716, 0.8716],
          [0.8706, 0.8706, 0.8706,  ..., 0.8706, 0.8706, 0.8706],
          [0.8711, 0.8711, 0.8711,  ..., 0.8711, 0.8711, 0.8711]]]],
       device='xpu:0')

For 4 48 1024 64, both are OK.

triton_output tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]]],
       device='xpu:0')
===========================================================
torch tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9995, 0.9995, 0.9995,  ..., 0.9995, 0.9995, 0.9995],
          [0.9995, 0.9995, 0.9995,  ..., 0.9995, 0.9995, 0.9995],
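
For reference, a minimal sketch of the all-ones-V check described above; the shapes follow the failing config, and the attention(...) call is a hypothetical stand-in for the tutorial's Triton forward:

import torch

Z, H, N_CTX, D_HEAD = 1, 32, 16384, 64  # the failing config
q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=torch.float16)
k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=torch.float16)
v = torch.ones((Z, H, N_CTX, D_HEAD), device='xpu', dtype=torch.float16)  # all-ones V

# Every row of P = softmax(QK^T * scale) sums to 1, so O = P @ V should be all ones.
torch_output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False).to(torch.float32)
print(torch_output)  # for 1 32 16384 64 this prints ~0.87 instead of ~1.0
# triton_output = attention(q, k, v, ...)  # hypothetical: the tutorial's Triton forward
# print(triton_output)                     # prints ~1.0 as expected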

@victor-eds

I see that for that particular corner case we are more precise (even in the second example). However, could we test this with random inputs, e.g., comparing against XeTLA and other vendors such as CUDA or CPU? If this is indeed an XeTLA issue, we could report it to them.
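
One possible sketch of such a check, assuming a hypothetical attention(...) wrapper for the tutorial's Triton kernel and using a float32 CPU SDPA as an extra reference (note this can be slow and memory-heavy for N_CTX=16384):

import torch

Z, H, N_CTX, D_HEAD = 1, 32, 16384, 64
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16)
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16)
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16)

# Float32 CPU reference to reduce accumulation error.
cpu_ref = torch.nn.functional.scaled_dot_product_attention(
    q.float(), k.float(), v.float(), attn_mask=None, dropout_p=0.0, is_causal=False)

# o = attention(q.to('xpu'), k.to('xpu'), v.to('xpu'), ...)  # hypothetical Triton call
# torch.testing.assert_close(o.cpu().float(), cpu_ref, atol=1e-2, rtol=0)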

@quintinwang5

I've verified that Triton's result matches CUDA's using atol=1e-2, rtol=0. So it's clear that PyTorch gives a wrong result for this case. I'll close this issue and file a new issue with our PyTorch team.
Note: if you want to verify the result yourself, be careful with the differing behavior of torch.manual_seed and torch.randn between CUDA and XPU. Even though we choose the same seed before calling randn three times for q, k, v, we get the same q but different k and v. I just save all of q, k, v, then load the same copy to avoid this problem.
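
A minimal sketch of that save/load approach (the file name and seed are illustrative):

import torch

# Generate the inputs once and persist them so the CUDA and XPU runs see identical data.
torch.manual_seed(0)
shape = (1, 32, 16384, 64)
q, k, v = (torch.randn(shape, dtype=torch.float16) for _ in range(3))
torch.save({'q': q, 'k': k, 'v': v}, 'qkv.pt')

# On each backend, load the same copy instead of re-sampling with randn.
data = torch.load('qkv.pt')
q, k, v = data['q'].to('xpu'), data['k'].to('xpu'), data['v'].to('xpu')  # or .to('cuda')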

@victor-eds

> I've verified that Triton's result matches CUDA's using atol=1e-2, rtol=0. So it's clear that PyTorch gives a wrong result for this case. I'll close this issue and file a new issue with our PyTorch team. Note: if you want to verify the result yourself, be careful with the differing behavior of torch.manual_seed and torch.randn between CUDA and XPU. Even though we choose the same seed before calling randn three times for q, k, v, we get the same q but different k and v. I just save all of q, k, v, then load the same copy to avoid this problem.

Thanks for the investigation! Good findings!

@quintinwang5

Track: pytorch/pytorch#135085

@vlad-penkin

Let's revisit this issue now that pytorch/pytorch#135085 is closed.

vlad-penkin reopened this on Dec 18, 2024
@quintinwang5

Still OOM on 1100, and it takes too much time on 1550 (not sure whether it's hung or just executing slowly). The torch used, 2.6.0a0+git61dc5e9, is the latest one from the CI build.

@quintinwang5

Confirmed that SDPA for XPU is a feature targeted for PyTorch 2.7.

@quintinwang5

Tracking in pytorch/pytorch#140389
