We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cab9659 commit 2cd1727Copy full SHA for 2cd1727
transformer_engine/jax/cpp_extensions/softmax.py
@@ -853,6 +853,7 @@ def jax_general_softmax(
853
result = jnp.where(where, result, 0)
854
return result
855
856
+
857
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
858
"""
859
scaled_softmax_forward wrapper
0 commit comments