Use log1p(x) instead of log(1+x) (#1401)
torch.log1p() is more accurate than torch.log() for small input values: https://pytorch.org/docs/stable/generated/torch.log1p.html

Found with TorchFix https://github.com/pytorch-labs/torchfix/
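
As a rough illustration (not part of this change): for a small float32 argument, 1 + x rounds back to 1.0, so torch.log(1 + x) returns 0.0 while torch.log1p(x) keeps the value. The -20.0 below is an arbitrary choice for demonstration.

import torch

# Sketch only, not from the commit: compare log(1 + x) and log1p(x)
# for a tiny x, where float32 cannot represent 1 + x distinctly from 1.0.
x = torch.exp(torch.tensor(-20.0, dtype=torch.float32))  # ~2.06e-9

old = torch.log(1 + x)   # 1 + 2.06e-9 rounds to 1.0 in float32 -> 0.0
new = torch.log1p(x)     # ~2.06e-9, close to the exact result

print(old.item(), new.item())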

Signed-off-by: Sergii Dymchenko <[email protected]>
Co-authored-by: Xiaowei Ren <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
3 people authored Jan 28, 2025
1 parent 2fce82b commit 199e612
Showing 1 changed file with 1 addition and 1 deletion.
transformer_engine/pytorch/attention.py: 1 addition & 1 deletion
@@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction(
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
-    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
     softmax_lse.copy_(new_scale)
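
For context, a rough standalone check (not from the commit) of what the patched line computes: merging two log-sum-exp values a and b into log(exp(a) + exp(b)) via a max shift, which stays finite even when exp() of the raw values would overflow. The helper name merge_lse is invented for illustration.

import torch

# Sketch only; merge_lse is a made-up name mirroring the patched expression.
def merge_lse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    max_scale = torch.max(a, b)
    min_scale = torch.min(a, b)
    return max_scale + torch.log1p(torch.exp(min_scale - max_scale))

a = torch.tensor(1000.0)
b = torch.tensor(995.0)
print(merge_lse(a, b))  # ~1000.0067 and finite
# A naive torch.log(torch.exp(a) + torch.exp(b)) returns inf here,
# because exp(1000.0) overflows in float32.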

