From 199e6123d56d03b376c4aa483a0a51f938b1bac4 Mon Sep 17 00:00:00 2001
From: Sergii Dymchenko <kit1980@gmail.com>
Date: Mon, 27 Jan 2025 16:51:16 -0800
Subject: [PATCH] Use log1p(x) instead of log(1+x) (#1401)

torch.log1p(x) is more accurate than torch.log(1 + x) for small values of x - https://pytorch.org/docs/stable/generated/torch.log1p.html

Found with TorchFix https://github.com/pytorch-labs/torchfix/
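
As a quick illustration (not part of the patch): in float32, 1 + x rounds to 1.0 for very small x, so torch.log(1 + x) returns 0.0 while torch.log1p(x) keeps the small value. The merged result is also equal to log(exp(a) + exp(b)), which torch.logaddexp computes directly:

    import torch

    x = torch.tensor([1e-8], dtype=torch.float32)
    print(torch.log(1 + x))  # tensor([0.]) -- 1 + 1e-8 rounds to 1.0 in float32
    print(torch.log1p(x))    # ~tensor([1.0000e-08]) -- accurate for small x

    # The corrected formula merges two log-sum-exp values:
    # log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b)))
    a = torch.tensor([2.0, -50.0])
    b = torch.tensor([1.5, -49.0])
    merged = torch.max(a, b) + torch.log1p(torch.exp(torch.min(a, b) - torch.max(a, b)))
    print(torch.allclose(merged, torch.logaddexp(a, b)))  # True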

Signed-off-by: Sergii Dymchenko <sdym@meta.com>
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index f2120f3a73..ccceacff85 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction(
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
-    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
     softmax_lse.copy_(new_scale)