Skip to content

Commit 2cd1727

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent cab9659 commit 2cd1727

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

transformer_engine/jax/cpp_extensions/softmax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ def jax_general_softmax(
853853
result = jnp.where(where, result, 0)
854854
return result
855855

856+
856857
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
857858
"""
858859
scaled_softmax_forward wrapper

0 commit comments

Comments
 (0)