From 199e6123d56d03b376c4aa483a0a51f938b1bac4 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko <kit1980@gmail.com> Date: Mon, 27 Jan 2025 16:51:16 -0800 Subject: [PATCH] Use log1p(x) instead of log(1+x) (#1401) This function is more accurate than torch.log() for small values of input - https://pytorch.org/docs/stable/generated/torch.log1p.html Found with TorchFix https://github.com/pytorch-labs/torchfix/ Signed-off-by: Sergii Dymchenko <sdym@meta.com> Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f2120f3a73..ccceacff85 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction( """Merge softmax stats of each step in Attention with context parallelism""" max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) - new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) softmax_lse.copy_(new_scale)