Description
Comparing the Triton and XeTLA FlashAttention outputs with atol=1e-2, rtol=0, as in upstream, leads to the size (1, 32, 16384, 64) failing verification. A more relaxed atol=1e-1 passes, but this might be a bit too permissive, taking into account that the values will be less than 1 anyway (FlashAttention is a softmax).
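
For context on why atol=1e-1 feels loose: with rtol=0, the `torch.testing.assert_close` check reduces to a purely absolute bound, `|actual - expected| <= atol`, so the atol maps directly to an error budget in output units. A tiny illustration (values chosen arbitrarily):

```python
import torch

# With rtol=0 the comparison is purely absolute: |actual - expected| <= atol.
# Since the attention outputs stay below 1 in magnitude, atol=1e-1 tolerates
# errors of roughly 10% of the value range, while atol=1e-2 tolerates ~1%.
expected = torch.tensor([0.25])
assert torch.allclose(torch.tensor([0.33]), expected, atol=1e-1, rtol=0)      # passes
assert not torch.allclose(torch.tensor([0.33]), expected, atol=1e-2, rtol=0)  # fails
```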
To reproduce, add the following code to the `forward` function, right before the `return`:
```python
torch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False).to(torch.float32)
torch.testing.assert_close(o, torch_output, atol=1e-2, rtol=0)
```
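
For completeness, the same check can also be run as a standalone script rather than by patching `forward`. This is a minimal sketch, not the repository's test: it assumes a tutorial-style `attention(q, k, v, causal, sm_scale)` entry point (the `fused_attention` import is hypothetical and stands in for wherever the kernel under test lives), an `"xpu"` device, and `sm_scale` equal to SDPA's default 1/sqrt(head_dim) so both paths compute the same thing; the shape is the failing one from the report.

```python
import torch

# Hypothetical import; replace with the actual Triton FlashAttention
# entry point under test (tutorial-style attention(q, k, v, causal, sm_scale)).
from fused_attention import attention

BATCH, N_HEADS, SEQ_LEN, HEAD_DIM = 1, 32, 16384, 64  # the failing size from the report

torch.manual_seed(0)
q = torch.randn((BATCH, N_HEADS, SEQ_LEN, HEAD_DIM), dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 1.0 / (HEAD_DIM ** 0.5)  # matches SDPA's default scaling

# Cast both outputs to float32 before comparing so only numerics differ.
o = attention(q, k, v, False, sm_scale).to(torch.float32)
torch_output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
).to(torch.float32)

# Fails at the upstream tolerance for this shape; passes with atol=1e-1.
torch.testing.assert_close(o, torch_output, atol=1e-2, rtol=0)
```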